1# Owner(s): ["oncall: profiler"]
2
3import collections
4import gc
5import json
6import mmap
7import os
8import pickle
9import random
10import re
11import struct
12import subprocess
13import sys
14import tempfile
15import threading
16import time
17import unittest
18from dataclasses import dataclass, field
19from typing import List, Optional
20from unittest.mock import patch
21
22import expecttest
23
24import torch
25import torch.nn as nn
26import torch.optim
27import torch.utils.data
28from torch._C._profiler import _ExperimentalConfig, _ExtraFields_PyCall
29from torch.autograd.profiler import KinetoStepTracker, profile as _profile
30from torch.autograd.profiler_legacy import profile as _profile_legacy
31from torch.profiler import (
32    _utils,
33    DeviceType,
34    kineto_available,
35    profile,
36    ProfilerAction,
37    ProfilerActivity,
38    record_function,
39    supported_activities,
40)
41from torch.profiler._pattern_matcher import (
42    Conv2dBiasFollowedByBatchNorm2dPattern,
43    ExtraCUDACopyPattern,
44    ForLoopIndexingPattern,
45    FP32MatMulPattern,
46    GradNotSetToNonePattern,
47    MatMulDimInFP16Pattern,
48    NamePattern,
49    OptimizerSingleTensorPattern,
50    Pattern,
51    report_all_anti_patterns,
52    SynchronizedDataLoaderPattern,
53)
54from torch.testing._internal.common_cuda import TEST_MULTIGPU
55from torch.testing._internal.common_device_type import skipCUDAVersionIn
56from torch.testing._internal.common_utils import (
57    instantiate_parametrized_tests,
58    IS_ARM64,
59    IS_JETSON,
60    IS_LINUX,
61    IS_WINDOWS,
62    parametrize,
63    run_tests,
64    serialTest,
65    skipIfTorchDynamo,
66    TemporaryDirectoryName,
67    TemporaryFileName,
68    TEST_WITH_ASAN,
69    TEST_WITH_CROSSREF,
70    TEST_WITH_ROCM,
71    TestCase,
72)
73
74
75# If tqdm is not shut down properly, it will leave its monitor thread alive.
76# This causes an issue in the multithreading test because we check all events
77# in that test against their tids. The events that correspond to these lingering
78# threads all have a TID of (uint64_t)(-1), which is invalid.
79# The workaround is to turn off the monitor thread when tqdm is loaded.
80# Since these are unit tests, it is safe to disable the monitor thread.
81try:
82    import tqdm
83
84    tqdm.tqdm.monitor_interval = 0
85except ImportError:
86    pass
87
88try:
89    import psutil
90
91    HAS_PSUTIL = True
92except ModuleNotFoundError:
93    HAS_PSUTIL = False
94    psutil = None
95
96
97@unittest.skipIf(not HAS_PSUTIL, "Requires psutil to run")
98@unittest.skipIf(TEST_WITH_ASAN, "Cannot test with ASAN")
99@unittest.skipIf(IS_WINDOWS, "Test is flaky on Windows")
100@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
101class TestProfilerCUDA(TestCase):
102    @skipCUDAVersionIn([(11, 5)])  # https://github.com/pytorch/pytorch/issues/69023
103    def test_mem_leak(self):
104        """Checks that there's no memory leak when using profiler with CUDA"""
105        t = torch.rand(1, 1).cuda()
106        p = psutil.Process()
107        last_rss = collections.deque(maxlen=5)
108        for outer_idx in range(10):
109            with _profile(use_cuda=True):
110                for _ in range(1024):
111                    t = torch.mm(t, t)
112
113            gc.collect()
114            torch.cuda.empty_cache()
115            last_rss.append(p.memory_info().rss)
116
117        # When CUDA events leaked, the memory increase was ~7 MB between the
118        # profiler invocations above.
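        # The heuristic below only flags a leak if RSS grew monotonically across the
        # last five samples AND at least one step grew by more than 100 KiB, which
        # keeps the check robust to ordinary allocator noise.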
119        is_increasing = all(
120            last_rss[idx] > last_rss[idx - 1] for idx in range(1, len(last_rss))
121        )
122        max_diff = -1
123        for idx in range(1, len(last_rss)):
124            max_diff = max(max_diff, last_rss[idx] - last_rss[idx - 1])
125        self.assertTrue(
126            not (is_increasing and max_diff > 100 * 1024),
127            msg=f"memory usage is increasing, {str(last_rss)}",
128        )
129
130    def test_custom_module_input_op_ids(self):
131        class MyFunc(torch.autograd.Function):
132            @staticmethod
133            def forward(ctx, x):
134                ctx.save_for_backward(x)
135                return x
136
137            @staticmethod
138            def backward(ctx, gO):
139                (x,) = ctx.saved_tensors
140                return x
141
142        def custom_layer(input_ten):
143            return MyFunc.apply(input_ten)
144
145        # Only testing that emit_nvtx runs when the
146        # record_shapes option is enabled.
147        with torch.autograd.profiler.emit_nvtx(record_shapes=True) as prof:
148            x = torch.randn(10, 10, requires_grad=True)
149            y = torch.randn(10, 10, requires_grad=True)
150            z = x + y
151            s = custom_layer(z)
152            q = s.sum()
153            q.backward()
154
155    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
156    def test_cudagraph_profiling_workaround(self):
157        import subprocess
158
159        # repro taken from #75504
160        # Launch in a separate process to catch hanging/illegal memory errors
161        # and to make sure CUPTI isn't already initialized.
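        # CUPTI generally needs to be initialized before CUDA graphs are captured;
        # the embedded script below therefore calls
        # torch.profiler._utils._init_for_cuda_graphs() prior to building the graphed
        # callable (see the discussion in #75504 referenced above).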
162        p = subprocess.check_call(
163            [
164                sys.executable,
165                "-c",
166                """
167import os
168import torch
169from torch.profiler import ProfilerActivity, profile
170
171def add_one(in_: torch.Tensor):
172    return in_ + 1
173
174sample_arg = torch.zeros(10, device="cuda").requires_grad_(True)
175
176# add this before cuda graphs are created
177torch.profiler._utils._init_for_cuda_graphs()
178
179add_one_graphed = torch.cuda.graphs.make_graphed_callables(add_one, sample_args=(sample_arg,))
180zeros = torch.zeros(10, device="cuda")
181out = add_one_graphed(zeros)
182assert out[0] == 1
183
184with profile(activities=[ProfilerActivity.CPU]):
185    add_one_graphed(zeros)
186
187with profile(activities=[ProfilerActivity.CUDA]):
188    add_one_graphed(zeros)
189""",
190            ],
191            universal_newlines=True,
192            timeout=60,
193        )
194
195        # ^ this will throw an exception if the script fails.
196
197
198@unittest.skipIf(not torch.profiler.itt.is_available(), "ITT is required")
199class TestProfilerITT(TestCase):
200    def test_custom_module_input_op_ids(self):
201        class MyFunc(torch.autograd.Function):
202            @staticmethod
203            def forward(ctx, x):
204                ctx.save_for_backward(x)
205                return x
206
207            @staticmethod
208            def backward(ctx, gO):
209                (x,) = ctx.saved_tensors
210                return x
211
212        def custom_layer(input_ten):
213            return MyFunc.apply(input_ten)
214
215        # Only testing that emit_itt runs when the
216        # record_shapes option is enabled.
217        with torch.autograd.profiler.emit_itt(record_shapes=True) as prof:
218            x = torch.randn(10, 10, requires_grad=True)
219            y = torch.randn(10, 10, requires_grad=True)
220            z = x + y
221            s = custom_layer(z)
222            q = s.sum()
223            q.backward()
224
225
226@instantiate_parametrized_tests
227class TestProfiler(TestCase):
228    @unittest.skipIf(
229        TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
230    )
231    def test_source(self):
232        """Checks that source code attribution works for eager, TS and autograd mode"""
233        # avoid automatic inlining
234        prev_opt = torch._C._get_graph_executor_optimize()
235        torch._C._set_graph_executor_optimize(False)
236
237        @torch.jit.script
238        def ts_method_2(x, y):
239            return torch.matmul(x, y)
240
241        @torch.jit.script
242        def ts_method_1(x, y, z):
243            a = x + z
244            w = ts_method_2(x, y) + a
245            return w.sum()
246
247        class DummyModule(nn.Module):
248            def __init__(self) -> None:
249                super().__init__()
250                self.conv = torch.nn.Conv2d(
251                    3, 2, kernel_size=1, stride=2, padding=3, bias=False
252                )
253
254            def forward(self, x):
255                return self.conv(x)
256
257        mod = DummyModule()
258
259        def call_module(x):
260            return mod(x)
261
262        with _profile(
263            with_stack=True,
264            use_kineto=kineto_available(),
265            experimental_config=_ExperimentalConfig(verbose=True),
266        ) as p:
267            x = torch.randn(10, 10, requires_grad=True)
268            y = torch.randn(10, 10, requires_grad=True)
269            z = x + y
270            w = ts_method_1(x, y, z)
271            v = 2 * w
272            v.backward()
273            a = torch.randn(2, 3, 2, 2, requires_grad=True)
274            b = call_module(a)
275            c = b.sum()
276            c.backward()
277
278        for e in p.function_events:
279            if "aten::add" in e.name or "AddBackward" in e.name:
280                self.assertTrue(any("test_profiler" in entry for entry in e.stack))
281                self.assertTrue(
282                    any(
283                        (
284                            "test_source" in entry
285                            or "ts_method_1" in entry
286                            or "ts_method_2" in entry
287                        )
288                        for entry in e.stack
289                    )
290                )
291
292        # TODO: https://github.com/pytorch/kineto/issues/617
293        if kineto_available() and not IS_WINDOWS:
294            with TemporaryFileName(mode="w+") as fname:
295                p.export_chrome_trace(fname)
296                with open(fname) as f:
297                    events = json.load(f)["traceEvents"]
298
299                def extract(pattern: str):
300                    matches = [e for e in events if re.search(pattern, e["name"])]
301                    self.assertEqual(
302                        len(matches), 1, repr([e["name"] for e in matches])
303                    )
304                    return matches[0]
305
306                module_event = extract(r"DummyModule_0")
307                wrapper_event = extract(r"call_module")
308                self.assertEqual(
309                    module_event["args"]["Python parent id"],
310                    wrapper_event["args"]["Python id"],
311                )
312
313        torch._C._set_graph_executor_optimize(prev_opt)
314
315    @parametrize(
316        "name,thread_spec",
317        {
318            "basic": ((False, False),),
319            "multiple_preexisting": ((False, False),) * 2,
320            "open_in_scope": ((True, False),),
321            "close_in_scope": ((False, True),),
322            "complex": (
323                # Large number of background threads
324                (False, False),
325                (False, False),
326                (False, False),
327                (False, False),
328                # some of which finish during profiling
329                (False, True),
330                (False, True),
331                # And the profiled section is also multithreaded
332                (True, False),
333                (True, True),
334            ),
335        }.items(),
336        name_fn=lambda name, thread_spec: name,
337    )
338    @serialTest()
339    @parametrize("work_in_main_thread", [True, False])
340    def test_source_multithreaded(self, name, thread_spec, work_in_main_thread):
341        """Test various threading configurations.
342
343        `thread_spec` is a Tuple[Tuple[bool, bool], ...] where each pair describes a
344        thread: the first bool indicates whether the thread should be started under
345        the profiler context and the second whether it should be joined under the
346        profiler context.
347        """
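        # For example, a spec of ((False, True),) describes one thread that is started
        # before the profiler context is entered and joined while it is still active.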
348
349        timeout = 15
350        num_threads = len(thread_spec) + 1  # Main thread
351        start_barrier = threading.Barrier(num_threads, timeout=timeout)
352        end_barrier = threading.Barrier(num_threads, timeout=timeout)
353
354        class Task(threading.Thread):
355            def __init__(self) -> None:
356                self._end_gate = threading.Event()
357                super().__init__(daemon=True)
358                self.start()
359                self.finished = False
360
361            def run(self):
362                self._run(self._end_gate)
363
364            def release(self):
365                self._end_gate.set()
366
367            @staticmethod
368            def _run(end_gate=None):
369                def known_preexisting_function():
370                    start_barrier.wait()
371
372                # Fixed point that we can use to test capture of functions
373                # which are already running when profiling is enabled.
374                known_preexisting_function()
375
376                model = torch.nn.Sequential(
377                    torch.nn.Linear(10, 10),
378                    torch.nn.ReLU(),
379                )
380
381                def invoked_during_run():
382                    pass
383
384                invoked_during_run()
385
386                _ = model(torch.rand(4, 10))
387                end_barrier.wait()
388
389                if end_gate is not None:
390                    end_gate.wait(timeout=timeout)
391
392        threads = {}
393
394        def add_threads(context: bool):
395            for idx, (start_under_profiler, _) in enumerate(thread_spec):
396                if start_under_profiler == context:
397                    assert idx not in threads
398                    threads[idx] = Task()
399
400        def join_threads(context: bool):
401            for idx, (_, end_under_profiler) in enumerate(thread_spec):
402                if end_under_profiler == context:
403                    threads[idx].release()
404
405            for idx, (_, end_under_profiler) in enumerate(thread_spec):
406                t = threads[idx]
407                if end_under_profiler == context:
408                    t.join(timeout=timeout)
409
410        try:
411            add_threads(False)
412            with torch.profiler.profile(with_stack=True) as prof:
413                # Threads added while the profiler is running will not be observed
414                # since there is no way to hook into Python's thread start call to
415                # register the observer. These are here purely to verify safety.
416                add_threads(True)
417
418                if work_in_main_thread:
419                    Task._run()
420                else:
421                    start_barrier.wait()
422                    end_barrier.wait()
423
424                join_threads(True)
425            join_threads(False)
426
427        finally:
428            # It is very important that we clean up everything because the
429            # Python tracer will detect ALL active threads. (Even orphans from
430            # prior failed tests.) If we don't clean up properly we can
431            # contaminate subsequent tests.
432            start_barrier.abort()
433            end_barrier.abort()
434            for t in threads.values():
435                t.release()
436
437            for t in threads.values():
438                t.join(timeout=timeout)
439
440            for t in threads.values():
441                self.assertFalse(t.is_alive())
442
443        roots = prof.profiler.kineto_results.experimental_event_tree()
444        nodes = [
445            node
446            for node in _utils.traverse_dfs(roots)
447            if isinstance(node.extra_fields, _ExtraFields_PyCall)
448        ]
449        tid_counts = collections.Counter([node.start_tid for node in nodes])
450
451        prior_threads = sum(
452            not start_under_profiler for start_under_profiler, _ in thread_spec
453        )
454        expected_threads = prior_threads + 1
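        # Only threads that already existed when profiling started (plus the main
        # thread) can be observed; threads spawned under the profiler are invisible
        # to the Python tracer, as noted near the add_threads(True) call above.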
455        self.assertEqual(
456            len(tid_counts), expected_threads, f"{expected_threads}, {tid_counts}"
457        )
458        self.assertEqual(len(nodes), sum(tid_counts.values()))
459
460        # Profiler uses uint64_t max as a placeholder until TID can be determined.
461        no_tid = 2**64 - 1
462        self.assertFalse(no_tid in tid_counts)
463
464        worker_threads = prior_threads + (1 if work_in_main_thread else 0)
465
466        observed_preexisting = [
467            node.start_tid
468            for node in nodes
469            if "known_preexisting_function" in node.name
470        ]
471        self.assertEqual(len(observed_preexisting), worker_threads)
472        self.assertEqual(len(observed_preexisting), len(set(observed_preexisting)))
473
474        observed_during_run = [
475            node.start_tid for node in nodes if "invoked_during_run" in node.name
476        ]
477        self.assertEqual(len(observed_during_run), worker_threads)
478        self.assertEqual(len(observed_during_run), len(set(observed_during_run)))
479
480    def payload(self, use_cuda=False):
481        x = torch.randn(10, 10)
482        if use_cuda:
483            x = x.cuda()
484        y = torch.randn(10, 10)
485        if use_cuda:
486            y = y.cuda()
487        z = torch.mm(x, y)
488        z = z + y
489        if use_cuda:
490            z = z.cpu()
491
492    def _check_stats(self, profiler_stats):
493        self.assertGreater(profiler_stats.profiling_window_duration_sec, 0)
494        self.assertGreater(profiler_stats.number_of_events, 0)
495        self.assertGreater(profiler_stats.profiler_prepare_call_duration_us, 0)
496        self.assertGreater(profiler_stats.profiler_enable_call_duration_us, 0)
497        self.assertGreater(profiler_stats.profiler_disable_call_duration_us, 0)
498        self.assertGreater(profiler_stats.parse_kineto_call_duration_us, 0)
499        self.assertGreater(
500            profiler_stats.function_events_build_tree_call_duration_us, 0
501        )
502
503    @unittest.skipIf(not kineto_available(), "Kineto is required")
504    def test_kineto(self):
505        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
506        with _profile(use_cuda=use_cuda, use_kineto=True):
507            self.payload(use_cuda=use_cuda)
508
509        # rerun to avoid initial start overhead
510        with _profile(use_cuda=use_cuda, use_kineto=True) as p:
511            self.payload(use_cuda=use_cuda)
512
513        self.assertTrue("aten::mm" in str(p))
514
515        output = p.key_averages().table(
516            sort_by="self_cuda_time_total" if use_cuda else "self_cpu_time_total",
517            row_limit=-1,
518        )
519        # print(output)
520        found_gemm = False
521        found_memcpy = False
522        found_mm = False
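        # Under CUDA, the torch.mm in payload() should show up as a device gemm
        # kernel (a "gemm"-named kernel, or a "Cijk..."-named kernel on ROCm builds)
        # and the .cuda()/.cpu() transfers as memcpy events; on CPU we only expect
        # the aten::mm op itself.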
523        for e in p.function_events:
524            if "aten::mm" in e.name:
525                found_mm = True
526            if "gemm" in e.name.lower() or "Cijk" in e.name:
527                found_gemm = True
528            if "memcpy" in e.name.lower():
529                found_memcpy = True
530        if use_cuda:
531            self.assertTrue(found_gemm)
532            self.assertTrue(found_memcpy)
533        else:
534            self.assertTrue(found_mm)
535        self._check_stats(p._stats)
536        # p.export_chrome_trace("/tmp/test_trace.json")
537
538    @unittest.skipIf(not kineto_available(), "Kineto is required")
539    @unittest.skipIf(not TEST_MULTIGPU, "Multiple GPUs needed")
540    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
541    def test_kineto_multigpu(self):
542        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
543            for gpu_id in [0, 1]:
544                x = torch.randn(10, 10).cuda(gpu_id)
545                y = torch.randn(10, 10).cuda(gpu_id)
546                z = x.matmul(y)
547
548        found_gemm_0 = False
549        found_gemm_1 = False
550        found_cuda = False
551        for evt in prof.events():
552            if "gemm" in evt.name.lower() and evt.device_type == DeviceType.CUDA:
553                if evt.device_index == 0:
554                    found_gemm_0 = True
555                elif evt.device_index == 1:
556                    found_gemm_1 = True
557            if "cuda" in evt.name.lower() and evt.device_type == DeviceType.CPU:
558                found_cuda = True
559
560        self.assertTrue(found_gemm_0)
561        self.assertTrue(found_gemm_1)
562        self.assertTrue(found_cuda)
563        self._check_stats(prof._stats())
564
565    def test_memory_profiler(self):
566        def run_profiler(tensor_creation_fn):
567            # collecting allocs / deallocs
568            with _profile(
569                profile_memory=True,
570                record_shapes=True,
571                use_kineto=kineto_available(),
572            ) as prof:
573                x = None
574                with record_function("test_user_scope_alloc"):
575                    x = tensor_creation_fn()
576                with record_function("test_user_scope_dealloc"):
577                    del x
578            return prof.key_averages(group_by_input_shape=True)
579
580        def check_metrics(stats, metric, allocs=None, deallocs=None):
581            stat_metrics = {}
582            # print(stats)
583            for stat in stats:
584                stat_metrics[stat.key] = getattr(stat, metric)
585            # print(stat_metrics)
586            if allocs is not None:
587                for alloc_fn in allocs:
588                    self.assertTrue(alloc_fn in stat_metrics)
589                    self.assertGreater(
590                        stat_metrics[alloc_fn], 0, f"alloc_fn = {alloc_fn}"
591                    )
592            if deallocs is not None:
593                for dealloc_fn in deallocs:
594                    self.assertTrue(dealloc_fn in stat_metrics)
595                    self.assertLess(
596                        stat_metrics[dealloc_fn], 0, f"dealloc_fn = {dealloc_fn}"
597                    )
598
599        def create_cpu_tensor():
600            return torch.rand(10, 10)
601
602        def create_cuda_tensor():
603            return torch.rand(10, 10).cuda()
604
605        def create_mkldnn_tensor():
606            return torch.rand(10, 10, dtype=torch.float32).to_mkldnn()
607
608        stats = run_profiler(create_cpu_tensor)
609        check_metrics(
610            stats,
611            "cpu_memory_usage",
612            allocs=[
613                "aten::empty",
614                "aten::rand",
615                "test_user_scope_alloc",
616            ],
617            deallocs=[
618                "test_user_scope_dealloc",
619            ],
620        )
621
622        if kineto_available():
623            with TemporaryFileName(mode="w+") as fname:
624                with profile(profile_memory=True) as prof:
625                    x = None
626                    with record_function("test_user_scope_alloc"):
627                        x = create_cpu_tensor()
628                    with record_function("test_user_scope_dealloc"):
629                        del x
630                prof.export_chrome_trace(fname)
631                with open(fname) as f:
632                    trace = json.load(f)
633                    assert "traceEvents" in trace
634                    events = trace["traceEvents"]
635                    found_memory_events = False
636                    for evt in events:
637                        assert "name" in evt
638                        if evt["name"] == "[memory]":
639                            found_memory_events = True
640                            assert "args" in evt
641                            assert "Addr" in evt["args"]
642                            assert "Device Type" in evt["args"]
643                            assert "Device Id" in evt["args"]
644                            assert "Bytes" in evt["args"]
645
646                            # Memory should be an instantaneous event.
647                            assert "dur" not in evt["args"]
648                            assert "cat" not in evt["args"]
649                    assert found_memory_events
650
651        if torch.cuda.is_available():
652            create_cuda_tensor()
653            stats = run_profiler(create_cuda_tensor)
654            check_metrics(
655                stats,
656                "device_memory_usage",
657                allocs=[
658                    "test_user_scope_alloc",
659                    "aten::to",
660                    "aten::empty_strided",
661                ],
662                deallocs=[
663                    "test_user_scope_dealloc",
664                ],
665            )
666            check_metrics(
667                stats,
668                "cpu_memory_usage",
669                allocs=[
670                    "aten::rand",
671                    "aten::empty",
672                ],
673            )
674
675        if torch.backends.mkldnn.is_available():
676            create_mkldnn_tensor()
677            stats = run_profiler(create_mkldnn_tensor)
678            check_metrics(
679                stats,
680                "cpu_memory_usage",
681                allocs=[
682                    "test_user_scope_alloc",
683                    "aten::rand",
684                    "aten::empty",
685                    "aten::to_mkldnn",
686                ],
687                deallocs=[
688                    "test_user_scope_dealloc",
689                ],
690            )
691
692        # check top-level memory events
693        with _profile(profile_memory=True, use_kineto=kineto_available()) as prof:
694            x = torch.rand(10, 10)
695            del x
696            if torch.cuda.is_available():
697                y = torch.rand(10, 10).cuda()
698                del y
699            gc.collect()
700        stats = prof.key_averages(group_by_input_shape=True)
701        check_metrics(
702            stats,
703            "cpu_memory_usage",
704            allocs=["aten::rand", "aten::empty"],
705            deallocs=["[memory]"],
706        )
707        if torch.cuda.is_available():
708            check_metrics(stats, "device_memory_usage", deallocs=["[memory]"])
709
710    @unittest.skipIf(
711        IS_JETSON, "Jetson has a guard against OOM since host and gpu memory are shared"
712    )
713    def test_oom_tracing(self):
714        def run_profiler(tensor_creation_fn):
715            with _profile(profile_memory=True, record_shapes=True) as prof:
716                with self.assertRaisesRegex(RuntimeError, ".*[tT]ried to allocate.*"):
717                    x = tensor_creation_fn()
718                return prof
719
720        def create_cuda_tensor_oom():
721            device = torch.device("cuda:0")
722            return torch.empty(
723                1024, 1024, 1024, 1024, dtype=torch.float32, device=device
724            )
725
726        def check_trace(fname):
727            prof.export_chrome_trace(fname)
728            with open(fname) as f:
729                trace = json.load(f)
730                self.assertTrue("traceEvents" in trace)
731                events = trace["traceEvents"]
732                found_out_of_memory_events = False
733                for evt in events:
734                    self.assertTrue("name" in evt)
735                    if evt["name"] == "[OutOfMemory]":
736                        found_out_of_memory_events = True
737                        self.assertTrue("args" in evt)
738                        self.assertTrue("Device Type" in evt["args"])
739                        self.assertTrue("Device Id" in evt["args"])
740                        self.assertTrue("Bytes" in evt["args"])
741
742                        # Memory should be an instantaneous event.
743                        self.assertTrue("dur" not in evt["args"])
744                        self.assertTrue("cat" not in evt["args"])
745                self.assertTrue(found_out_of_memory_events)
746
747        if torch.cuda.is_available():
748            with TemporaryFileName(mode="w+") as fname:
749                prof = run_profiler(create_cuda_tensor_oom)
750                check_trace(fname)
751
752    @unittest.skipIf(not kineto_available(), "Kineto is required")
753    def test_module_hierarchy(self):
754        class A(nn.Module):
755            def my_new_method(self, x):
756                return x * 3
757
758            def forward_impl_(self, x, y):
759                return self.my_new_method(x) + y
760
761            def forward(self, x, y):
762                y = y - 2
763                return self.forward_impl_(x, y)
764
765        class B(nn.Module):
766            def forward(self, x):
767                return x + 2
768
769        class C(nn.Module):
770            def __init__(self) -> None:
771                super().__init__()
772                self.A0 = A()
773                self.B0 = B()
774
775            def call_b(self, x):
776                return self.B0.forward(x)
777
778            def forward(self, x, y):
779                return self.A0.forward(x, y) + self.call_b(x)
780
781        model = C()
782        model = torch.jit.script(model)
783        input_a = torch.rand(128, 128)
784        input_b = torch.rand(128, 128)
785        op_to_module_hierarchy = {}
786        op_to_module_hierarchy["aten::sub"] = ["TOP(C)::forward.A0(A)::forward."]
787        op_to_module_hierarchy["aten::mul"] = [
788            "TOP(C)::forward.A0(A)::forward.SELF(A)::forward_impl_.SELF(A)::my_new_method."
789        ]
790        op_to_module_hierarchy["aten::add"] = [
791            "TOP(C)::forward.A0(A)::forward.SELF(A)::forward_impl_.",
792            "TOP(C)::forward.SELF(C)::call_b.B0(B)::forward.",
793            "TOP(C)::forward.",
794        ]
795        with TemporaryFileName(mode="w+") as fname:
796            with profile(
797                activities=[torch.profiler.ProfilerActivity.CPU],
798                with_modules=True,
799            ) as prof:
800                model(input_a, input_b)
801            prof.export_chrome_trace(fname)
802            with open(fname) as f:
803                trace = json.load(f)
804                assert "traceEvents" in trace
805                events = trace["traceEvents"]
806                found_memory_events = False
807                for evt in events:
808                    assert "name" in evt
809                    if "args" in evt:
810                        op_name = evt["name"]
811                        if "Module Hierarchy" in evt["args"]:
812                            hierarchy = evt["args"]["Module Hierarchy"]
813                            if op_name in op_to_module_hierarchy:
814                                assert hierarchy in op_to_module_hierarchy[op_name]
815
816    def test_high_level_trace(self):
817        """Checks that python side high level events are recorded."""
818
819        class RepeatedDataset(torch.utils.data.Dataset):
820            def __init__(self, N, D_in, D_out):
821                self.N = N
822                self.x = torch.randn(N, D_in)
823                self.y = torch.randn(N, D_out)
824
825            def __len__(self):
826                return self.N
827
828            def __getitem__(self, idx):
829                return self.x, self.y
830
831        class TwoLayerNet(torch.nn.Module):
832            def __init__(self, D_in, H, D_out):
833                super().__init__()
834                self.linear1 = torch.nn.Linear(D_in, H)
835                self.linear2 = torch.nn.Linear(H, D_out)
836
837            def forward(self, x):
838                h_relu = self.linear1(x).clamp(min=0)
839                y_pred = self.linear2(h_relu)
840                return y_pred
841
842        class CustomSGD(torch.optim.SGD):
843            def __init__(self, *args, **kwargs):
844                super().__init__(*args, **kwargs)
845
846        def train():
847            for _, data in enumerate(dataloader):
848                x, y = data[0], data[1]
849                y_pred = model(x)
850                loss = criterion(y_pred, y)
851                optimizer.zero_grad()
852                loss.backward()
853                optimizer.step()
854
855        N, D_in, H, D_out = 8, 10, 5, 2
856        model = TwoLayerNet(D_in, H, D_out)
857        criterion = torch.nn.MSELoss(reduction="sum")
858        optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
859        ds = RepeatedDataset(N, D_in, D_out)
860        dataloader = torch.utils.data.DataLoader(ds, batch_size=1)
861
862        try:
863            train()
864        except Exception:
865            self.assertTrue(False, "Expected no exception without profiling.")
866
867        # Create multiple instances; expect each function to be hooked only once.
868        # Nested wrappers (repeated patching) would make the following test fail.
869        optimizer_duplicate = torch.optim.SGD(model.parameters(), lr=1e-4)
870        dataloader_duplicate = torch.utils.data.DataLoader(ds, batch_size=1)
871
872        def judge(expected_event_count, prof):
873            actual_event_count = {}
874            for e in prof.function_events:
875                if "#" in e.name:
876                    key = e.name
877                    if key in expected_event_count.keys():
878                        actual_event_count[key] = (
879                            actual_event_count.setdefault(key, 0) + 1
880                        )
881            for key, count in expected_event_count.items():
882                self.assertTrue(
883                    (key in actual_event_count.keys())
884                    and (count == actual_event_count[key])
885                )
886
887        with _profile(use_kineto=kineto_available()) as prof:
888            train()
889        expected_event_count = {
890            # "+1" because the final iteration will enter __next__ but skip the loop body.
891            "enumerate(DataLoader)#_SingleProcessDataLoaderIter.__next__": (N + 1),
892            "Optimizer.step#SGD.step": N,
893            "Optimizer.zero_grad#SGD.zero_grad": N,
894        }
895        judge(expected_event_count, prof)
896
897        # Test on pickle/unpickle. Expect to work in multi-processing.
898        optimizer = pickle.loads(pickle.dumps(optimizer))
899        with _profile(use_kineto=kineto_available()) as prof:
900            train()
901        judge(expected_event_count, prof)
902
903        # Test on customized optimizer.
904        optimizer = CustomSGD(model.parameters(), lr=1e-4)
905        with _profile(use_kineto=kineto_available()) as prof:
906            train()
907        expected_event_count = {
908            "enumerate(DataLoader)#_SingleProcessDataLoaderIter.__next__": (N + 1),
909            "Optimizer.step#CustomSGD.step": N,
910            "Optimizer.zero_grad#CustomSGD.zero_grad": N,
911        }
912        judge(expected_event_count, prof)
913
914    def test_flops(self):
915        model = torch.nn.Sequential(
916            nn.Conv2d(16, 33, 18),
917            nn.ReLU(),
918            nn.Linear(243, 243),
919            nn.ReLU(),
920        )
921        inputs = torch.randn(40, 16, 18, 260)
922        nested_tensor = torch.nested.nested_tensor(
923            [torch.randn((2, 5)), torch.randn((3, 5))], layout=torch.jagged
924        )
925        with _profile(
926            record_shapes=True, with_flops=True, use_kineto=kineto_available()
927        ) as prof:
928            model(inputs)
929            # test that nested tensor won't cause exception during flop compute
930            nested_tensor = nested_tensor + nested_tensor
931        profiler_output = prof.key_averages(group_by_input_shape=True).table(
932            sort_by="cpu_time_total", row_limit=10
933        )
934        self.assertIn("Total MFLOPs", profiler_output)
935        if not (kineto_available() and torch.cuda.is_available()):
936            return
937
938        with profile(
939            activities=[
940                torch.profiler.ProfilerActivity.CPU,
941                torch.profiler.ProfilerActivity.CUDA,
942            ],
943            record_shapes=True,
944            with_flops=True,
945        ) as kineto_profiler:
946            model(inputs)
947        profiler_output = kineto_profiler.key_averages().table(
948            sort_by="self_cuda_time_total", row_limit=-1
949        )
950        self.assertIn("Total MFLOPs", profiler_output)
951
952    def test_kineto_profiler_api(self):
953        called_num = [0]
954
955        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
956        with profile(activities=supported_activities()):
957            self.payload(use_cuda=use_cuda)
958
959        def trace_handler(p):
960            output = p.key_averages().table(
961                sort_by="self_cuda_time_total" if use_cuda else "self_cpu_time_total",
962                row_limit=-1,
963            )
964            # print(output)
965            # p.export_chrome_trace("/tmp/test_trace_" + str(called_num[0]) + ".json")
966            called_num[0] += 1
967
968        initial_step = KinetoStepTracker.current_step()
969
970        with profile(
971            activities=supported_activities(),
972            schedule=torch.profiler.schedule(wait=1, warmup=1, active=2),
973            on_trace_ready=trace_handler,
974        ) as p:
975            for idx in range(8):
976                self.payload(use_cuda=use_cuda)
977                p.step()
978
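        # With wait=1, warmup=1, active=2 the schedule cycle is 4 steps long, so over
        # 8 steps the trace handler fires once per completed cycle, i.e. exactly twice.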
979        self.assertEqual(called_num[0], 2)
980        self.assertEqual(KinetoStepTracker.current_step(), initial_step + 8)
981
982        # case without schedule
983        with profile(activities=supported_activities()) as p:
984            self.payload(use_cuda=use_cuda)
985            self.payload(use_cuda=use_cuda)
986        output = p.key_averages().table(
987            sort_by="self_cuda_time_total" if use_cuda else "self_cpu_time_total",
988            row_limit=-1,
989        )
990        # print(output)
991
992        test_schedule = torch.profiler.schedule(
993            skip_first=2, wait=1, warmup=1, active=2, repeat=2
994        )
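        # skip_first=2 yields two NONE steps; each repeat then runs the cycle
        # wait (NONE) -> warmup (WARMUP) -> active (RECORD, RECORD_AND_SAVE), and once
        # repeat=2 cycles have completed, every remaining step is NONE.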
995        test_schedule_expected_outputs = [
996            ProfilerAction.NONE,
997            ProfilerAction.NONE,
998            ProfilerAction.NONE,
999            ProfilerAction.WARMUP,
1000            ProfilerAction.RECORD,
1001            ProfilerAction.RECORD_AND_SAVE,
1002            ProfilerAction.NONE,
1003            ProfilerAction.WARMUP,
1004            ProfilerAction.RECORD,
1005            ProfilerAction.RECORD_AND_SAVE,
1006            ProfilerAction.NONE,
1007            ProfilerAction.NONE,
1008            ProfilerAction.NONE,
1009            ProfilerAction.NONE,
1010        ]
1011        for step in range(len(test_schedule_expected_outputs)):
1012            self.assertEqual(test_schedule(step), test_schedule_expected_outputs[step])
1013
1014    def test_kineto_profiler_multiple_steppers(self):
1015        niters = 8
1016        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
1017        net = SimpleNet()
1018        opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
1019        opt.zero_grad()
1020        inputs = torch.rand(10)
1021
1022        with profile(activities=supported_activities()):
1023            self.payload(use_cuda=use_cuda)
1024
1025        def optimizer_step():
1026            """This simulates a step() hook in the optimizer"""
1027            KinetoStepTracker.increment_step("yet_another_step")
1028
1029        initial_step = KinetoStepTracker.current_step()
1030
1031        def run_batch():
1032            out = net(inputs)
1033            loss = torch.nn.functional.cross_entropy(out, torch.rand(2))
1034            loss.backward()
1035            opt.step()
1036            # Manually call the hook. TODO: Remove this once we add the
1037            # profiler step hooks in the Optimizer class that will get triggered above.
1038            # See https://github.com/pytorch/pytorch/issues/88446
1039            optimizer_step()
1040
1041        for idx in range(niters):
1042            run_batch()
1043
1044        with profile(
1045            activities=supported_activities(),
1046            schedule=torch.profiler.schedule(wait=1, warmup=1, active=2),
1047        ) as p:
1048            for idx in range(niters):
1049                run_batch()
1050                p.step()
1051
1052        self.assertEqual(KinetoStepTracker.current_step(), initial_step + 2 * niters)
1053
1054    def test_export_stacks(self):
1055        with _profile(
1056            with_stack=True,
1057            use_kineto=kineto_available(),
1058            experimental_config=_ExperimentalConfig(verbose=True),
1059        ) as p:
1060            x = torch.randn(10, 10)
1061            y = torch.randn(10, 10)
1062            z = torch.mm(x, y)
1063            z = z + y
1064
1065        with TemporaryFileName(mode="w+") as fname:
1066            p.export_stacks(fname)
1067            with open(fname) as f:
1068                lines = f.readlines()
1069            assert len(lines) > 0, "Empty stacks file"
1070            for line in lines:
1071                is_int = False
1072                try:
1073                    assert int(line.split(" ")[-1]) > 0, "Invalid stacks record"
1074                    is_int = True
1075                except ValueError:
1076                    pass
1077                assert is_int, "Invalid stacks record"
1078
1079    @unittest.skipIf(not kineto_available(), "Kineto is required")
1080    def test_tensorboard_trace_handler(self):
1081        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
1082        with _profile(use_cuda=use_cuda, use_kineto=True):
1083            self.payload(use_cuda=use_cuda)
1084
1085        with TemporaryDirectoryName() as dname:
1086            with profile(
1087                activities=[torch.profiler.ProfilerActivity.CPU]
1088                + ([torch.profiler.ProfilerActivity.CUDA] if use_cuda else []),
1089                schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=3),
1090                on_trace_ready=torch.profiler.tensorboard_trace_handler(dname),
1091            ) as p:
1092                for _ in range(18):
1093                    self.payload(use_cuda=use_cuda)
1094                    p.step()
1095
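            # schedule(wait=1, warmup=1, active=2, repeat=3) saves one trace per
            # completed cycle, so 18 steps should leave exactly three
            # *.pt.trace.json files in the directory.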
1096            self.assertTrue(os.path.exists(dname))
1097            file_num = 0
1098            for file_name in os.listdir(dname):
1099                parts = file_name.split(".")
1100                self.assertTrue(len(parts) > 4)
1101                self.assertTrue(
1102                    parts[-4].isdigit() and int(parts[-4]) > 0,
1103                    "Wrong tracing file name pattern",
1104                )
1105                self.assertEqual(parts[-3:], ["pt", "trace", "json"])
1106                file_num += 1
1107            self.assertEqual(file_num, 3)
1108
1109        # test case for gzip file format
1110        with TemporaryDirectoryName() as dname:
1111            p = profile(
1112                activities=[torch.profiler.ProfilerActivity.CPU]
1113                + ([torch.profiler.ProfilerActivity.CUDA] if use_cuda else []),
1114                schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=3),
1115                on_trace_ready=torch.profiler.tensorboard_trace_handler(
1116                    dname, use_gzip=True
1117                ),
1118            )
1119            p.start()
1120            for _ in range(18):
1121                self.payload(use_cuda=use_cuda)
1122                p.step()
1123            p.stop()
1124
1125            self.assertTrue(os.path.exists(dname))
1126            file_num = 0
1127            for file_name in os.listdir(dname):
1128                parts = file_name.split(".")
1129                self.assertTrue(len(parts) > 4)
1130                self.assertTrue(
1131                    parts[-5].isdigit() and int(parts[-5]) > 0,
1132                    "Wrong tracing file name pattern",
1133                )
1134                self.assertEqual(parts[-4:], ["pt", "trace", "json", "gz"])
1135                file_num += 1
1136            self.assertEqual(file_num, 3)
1137
1138    @unittest.skipIf(not kineto_available(), "Kineto is required")
1139    def test_profiler_metadata(self):
1140        t1, t2 = torch.ones(1), torch.ones(1)
1141        with profile() as prof:
1142            torch.add(t1, t2)
1143            prof.add_metadata("test_key1", "test_value1")
1144            prof.add_metadata_json("test_key2", "[1,2,3]")
1145
1146        with TemporaryFileName(mode="w+") as fname:
1147            prof.export_chrome_trace(fname)
1148            with open(fname) as f:
1149                trace = json.load(f)
1150                assert "test_key1" in trace
1151                assert trace["test_key1"] == "test_value1"
1152                assert "test_key2" in trace
1153                assert trace["test_key2"] == [1, 2, 3]
1154
1155    def _test_profiler_tracing(self, use_kineto):
1156        with _profile(use_kineto=use_kineto) as prof:
1157            t1, t2 = torch.ones(1), torch.ones(1)
1158            torch.add(t1, t2)
1159
1160        with TemporaryFileName(mode="w+") as fname:
1161            prof.export_chrome_trace(fname)
1162            # read the trace and expect valid json
1163            # if the JSON generated by export_chrome_trace is not valid, this will throw and fail the test.
1164            with open(fname) as f:
1165                json.load(f)
1166
1167        # test empty trace
1168        with _profile(use_kineto=use_kineto) as prof:
1169            pass
1170        # saving an empty trace
1171        with TemporaryFileName(mode="w+") as fname:
1172            prof.export_chrome_trace(fname)
1173            if use_kineto:
1174                with open(fname) as f:
1175                    contents = json.load(f)
1176                    # Some builds may not have the logger observer,
1177                    # so skip the check in that case.
1178                    if "WARNING" in contents:
1179                        found_empty_warning = False
1180                        for warning in contents["WARNING"]:
1181                            if "No Valid Trace Events" in warning:
1182                                found_empty_warning = True
1183                        self.assertTrue(found_empty_warning)
1184
1185        # Same test but for cuda.
1186        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
1187        if not use_cuda:
1188            return
1189
1190        device = torch.device("cuda:0")
1191        with _profile(use_cuda=True, use_kineto=use_kineto) as prof:
1192            t1, t2 = torch.ones(1, device=device), torch.ones(1, device=device)
1193            torch.add(t1, t2)
1194
1195        with TemporaryFileName(mode="w+") as fname:
1196            prof.export_chrome_trace(fname)
1197            # Now validate the json
1198            with open(fname) as f:
1199                json.load(f)
1200
1201    def test_profiler_tracing(self):
1202        self._test_profiler_tracing(False)
1203        if kineto_available():
1204            self._test_profiler_tracing(True)
1205
1206    def test_profiler_op_event_args(self):
1207        torch._C._profiler._set_record_concrete_inputs_enabled_val(True)
1208        with _profile(record_shapes=True) as prof:
1209            a = torch.ones((64, 32), dtype=torch.float32)
1210            c = torch.cat([a, a]).sin()
1211        with TemporaryFileName(mode="w+") as fname:
1212            prof.export_chrome_trace(fname)
1213            with open(fname) as f:
1214                j = json.load(f)
1215                op_events = [
1216                    e for e in j["traceEvents"] if e.get("cat", "") == "cpu_op"
1217                ]
1218                for e in op_events:
1219                    args = e["args"]
1220                    if e["name"] == "aten::ones":
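                        # aten::ones(size, dtype, layout, device, pin_memory): the
                        # recorded concrete inputs are the size list, the ScalarType
                        # index for float32 ("6"), unset layout/device, and
                        # pin_memory=False.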
1221                        self.assertEqual(
1222                            args["Input type"],
1223                            ["ScalarList", "Scalar", "", "", "Scalar"],
1224                        )
1225                        self.assertEqual(
1226                            args["Concrete Inputs"], ["[64, 32]", "6", "", "", "False"]
1227                        )
1228
1229                    if e["name"] == "aten::cat":
1230                        self.assertEqual(args["Input Dims"], [[[64, 32], [64, 32]], []])
1231                        self.assertEqual(args["Input type"], ["TensorList", "Scalar"])
1232
1233                    # check that each op has record function id
1234                    self.assertGreaterEqual(
1235                        args.get("Record function id", -1),
1236                        0,
1237                        f"Failed finding record function id for op = {e}",
1238                    )
1239
1240    def test_profiler_strides(self):
1241        torch._C._profiler._set_record_concrete_inputs_enabled_val(True)
1242        base_tensor = torch.randn(1024, dtype=torch.float32)
1243        a = base_tensor.as_strided((16, 16), (17, 1), 0)
1244        b = base_tensor.as_strided((16, 16), (25, 2), 272)
1245        with _profile(record_shapes=True) as prof:
1246            c = torch.add(a, b)
1247
1248        with TemporaryFileName(mode="w+") as fname:
1249            prof.export_chrome_trace(fname)
1250            with open(fname) as f:
1251                j = json.load(f)
1252                op_events = [
1253                    e for e in j["traceEvents"] if e.get("cat", "") == "cpu_op"
1254                ]
1255                for e in op_events:
1256                    args = e["args"]
1257                    if e["name"] == "aten::add":
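                        # aten::add records (self, other, alpha); the scalar alpha has
                        # no strides, hence the trailing empty list.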
1258                        self.assertEqual(args["Input Strides"], [[17, 1], [25, 2], []])
1259
1260    def test_profiler_fwd_bwd_link(self):
1261        with _profile(use_kineto=True) as prof:
1262            t1, t2 = torch.ones(1, requires_grad=True), torch.ones(
1263                1, requires_grad=True
1264            )
1265            z = torch.add(t1, t2)
1266            y = torch.ones(1)
1267            loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
1268            loss.backward()
1269        with TemporaryFileName(mode="w+") as fname:
1270            prof.export_chrome_trace(fname)
1271            with open(fname) as f:
1272                j = json.load(f)
1273                events = j["traceEvents"]
1274                ts_to_name = {}
1275                flow_s_to_ts = {}
1276                flow_f_to_ts = {}
1277                for e in events:
1278                    if e["ph"] == "X":
1279                        ts_to_name[e["ts"]] = e["name"]
1280                    if (
1281                        "cat" in e
1282                        and "name" in e
1283                        and e["cat"] == "fwdbwd"
1284                        and e["name"] == "fwdbwd"
1285                    ):
1286                        if e["ph"] == "s":
1287                            flow_s_to_ts[e["id"]] = e["ts"]
1288                        elif e["ph"] == "f":
1289                            flow_f_to_ts[e["id"]] = e["ts"]
1290
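                # Two fwdbwd flow arrows are expected: one starting at aten::add and
                # one at aten::binary_cross_entropy_with_logits, each linking a forward
                # op to its backward counterpart ("s"/"f" are flow start/finish phases).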
1291                self.assertEqual(len(flow_s_to_ts), 2)
1292                self.assertEqual(len(flow_f_to_ts), 2)
1293                self.assertIn(1, flow_s_to_ts)
1294                self.assertIn(1, flow_f_to_ts)
1295                self.assertIn(2, flow_s_to_ts)
1296                self.assertIn(2, flow_f_to_ts)
1297                s_ts_1 = flow_s_to_ts[1]
1298                f_ts_1 = flow_f_to_ts[1]
1299                s_ts_2 = flow_s_to_ts[2]
1300                f_ts_2 = flow_f_to_ts[2]
1301                self.assertTrue(
1302                    all(
1303                        ts in ts_to_name.keys()
1304                        for ts in [s_ts_1, f_ts_1, s_ts_2, f_ts_2]
1305                    )
1306                )
1307                self.assertTrue(
1308                    ts_to_name[s_ts_1] == "aten::binary_cross_entropy_with_logits"
1309                )
1310                self.assertTrue(ts_to_name[s_ts_2] == "aten::add")
1311
1312    def test_profiler_disable_fwd_bwd_link(self):
1313        try:
1314            torch._C._profiler._set_fwd_bwd_enabled_val(False)
1315
1316            with _profile(use_kineto=True) as prof:
1317                t1, t2 = torch.ones(1, requires_grad=True), torch.ones(
1318                    1, requires_grad=True
1319                )
1320                z = torch.add(t1, t2)
1321                y = torch.ones(1)
1322                loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
1323                loss.backward()
1324
1325            with TemporaryFileName(mode="w+") as fname:
1326                prof.export_chrome_trace(fname)
1327                with open(fname) as f:
1328                    j = json.load(f)
1329                    events = j["traceEvents"]
1330
1331                    for e in events:
1332                        self.assertNotEqual(e.get("cat", None), "fwdbwd")
1333        finally:
1334            torch._C._profiler._set_fwd_bwd_enabled_val(True)
1335
1336    # This test is broken on Windows; the likely reason is that kineto/CUPTI
1337    # is not supported in that particular environment. Once the CI stabilizes,
1338    # we can narrow the condition so that Windows is checked as well (TODO).
1339    @unittest.skipIf(not kineto_available(), "Kineto is required")
1340    @unittest.skipIf(IS_WINDOWS, "Test does not work on Windows")
1341    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
1342    def test_profiler_cuda_sync_events(self):
1343        device = torch.device("cuda:0")
1344        t1, t2 = torch.ones(1, device=device), torch.ones(1, device=device)
1345
1346        def workload() -> None:
1347            torch.add(t1, t2)
1348            torch.cuda.synchronize()
1349            torch.add(t1, t2)
1350
1351        def trace_and_check(exp_config: Optional[_ExperimentalConfig]) -> None:
1352            with _profile(
1353                use_kineto=True,
1354                use_cuda=True,
1355                experimental_config=exp_config,
1356            ) as prof:
1357                workload()
1358
1359            with TemporaryFileName(mode="w+") as fname:
1360                # fname = "/tmp/kineto_out.json"
1361                prof.export_chrome_trace(fname)
1362                with open(fname) as f:
1363                    j = json.load(f)
1364                    cats = {e.get("cat", None) for e in j["traceEvents"]}
1365            self.assertTrue(
1366                "cuda_sync" in cats,
1367                f"Expected to find cuda_sync event, found = {cats}",
1368            )
1369
1370        print("Testing enable_cuda_sync_events in _ExperimentalConfig")
1371        trace_and_check(exp_config=_ExperimentalConfig(enable_cuda_sync_events=True))
1372
1373        print("Testing _profiler._set_cuda_sync_enabled_val()")
1374        try:
1375            torch._C._profiler._set_cuda_sync_enabled_val(True)
1376            trace_and_check(exp_config=None)
1377        finally:
1378            torch._C._profiler._set_cuda_sync_enabled_val(False)
1379
1380    def test_profiler_type(self):
1381        profiler_type = torch._C._autograd._profiler_type
1382        ActiveProfilerType = torch._C._profiler.ActiveProfilerType
1383        self.assertEqual(profiler_type(), ActiveProfilerType.NONE)
1384
1385        # Autograd profiler
1386        with _profile_legacy():
1387            self.assertEqual(profiler_type(), ActiveProfilerType.LEGACY)
1388
1389        # Kineto profiler
1390        with profile():
1391            self.assertEqual(profiler_type(), ActiveProfilerType.KINETO)
1392
1393    def test_profiler_correlation_id(self):
1394        """
1395        We expect the correlation_id to be unique across multiple invocations of the
1396        profiler, so we reuse the same id_uniqueness_set across iterations.
1397        """
1398        id_uniqueness_set = set()
1399        model = torch.nn.Sequential(
1400            nn.Conv2d(16, 33, 18),
1401            nn.ReLU(),
1402            nn.Linear(243, 243),
1403            nn.ReLU(),
1404        )
1405        inputs = torch.randn(40, 16, 18, 260)
1406        uint32_max = 2**32 - 1
1407        for i in range(5):
1408            with profile() as prof:
1409                model(inputs)
1410            for event in prof.profiler.kineto_results.events():
1411                corr_id = event.correlation_id()
1412                if corr_id and event.device_type() == DeviceType.CPU:
1413                    self.assertTrue(corr_id not in id_uniqueness_set)
1414                    id_uniqueness_set.add(corr_id)
1415                    self.assertTrue(corr_id < uint32_max)
1416
1417    def test_nested_tensor_with_shapes(self):
1418        a = torch.randn(4, 4)
1419        b = torch.randn(4, 4)
1420        c = torch.randn(4, 4)
1421        inp = torch.nested.nested_tensor([a, b])
1422        with torch.profiler.profile(record_shapes=True) as prof:
1423            torch.nn.functional.linear(inp, c, None)
1424        for e in prof.events():
1425            if e.name in ("aten::mm", "aten::addmm"):
1426                # intentionally vague tests to protect against possible future changes
1427                # of mm to addmm or other impl, or changing internal order of args
1428                self.assertTrue(len(e.input_shapes) > 0)
1429                self.assertTrue(len(e.input_shapes[0]) > 0)
1430
1431    @patch.dict(os.environ, {"KINETO_USE_DAEMON": "1"})
1432    @patch.dict(os.environ, {"KINETO_DAEMON_INIT_DELAY_S": "1"})
1433    def test_kineto_profiler_with_environment_variable(self):
1434        script = """
1435import torch
1436import torch.nn as nn
1437from torch.profiler import supported_activities, profile
1438from torch.autograd.profiler import KinetoStepTracker
1439
1440class SimpleNet(nn.Module):
1441    def __init__(self) -> None:
1442        super().__init__()
1443        self.fc1 = nn.Linear(10, 5)
1444        self.fc2 = nn.Linear(5, 2)
1445
1446    def forward(self, x):
1447        return self.fc2(self.fc1(x))
1448
1449
1450def payload(use_cuda=False):
1451    x = torch.randn(10, 10)
1452    if use_cuda:
1453        x = x.cuda()
1454    y = torch.randn(10, 10)
1455    if use_cuda:
1456        y = y.cuda()
1457    z = torch.mm(x, y)
1458    z = z + y
1459    if use_cuda:
1460        z = z.cpu()
1461
1462niters = 8
1463use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
1464net = SimpleNet()
1465opt = torch.optim.SGD(net.parameters(), lr=0.01)
1466opt.zero_grad()
1467inputs = torch.rand(10)
1468
1469with profile(activities=supported_activities()):
1470    payload(use_cuda=use_cuda)
1471
1472initial_step = KinetoStepTracker.current_step()
1473
1474def run_batch():
1475    out = net(inputs)
1476    loss = torch.nn.functional.cross_entropy(out, torch.rand(2))
1477    loss.backward()
1478    opt.step()
1479
1480for _ in range(niters):
1481    run_batch()
1482
1483with profile(
1484    activities=supported_activities(),
1485    schedule=torch.profiler.schedule(
1486        wait=1,
1487        warmup=1,
1488        active=2),
1489) as p:
1490    for _ in range(niters):
1491        run_batch()
1492        p.step()
1493assert KinetoStepTracker.current_step() == initial_step + 2 * niters
1494"""
1495        try:
1496            subprocess.check_output(
1497                [sys.executable, "-W", "always", "-c", script],
1498                cwd=os.path.dirname(os.path.realpath(__file__)),
1499            )
1500        except subprocess.CalledProcessError:
1501            # check_output only raises CalledProcessError for a non-zero exit
1502            # code, so reaching this handler means the subprocess failed
1503            self.fail(
1504                "Kineto is not working properly with the Dynolog environment variable"
1505            )
1506
1507    def test_concrete_inputs_profiling(self):
1508        x = torch.rand(2, 6)
1509        with profile(record_shapes=True) as p:
1510            y = x.as_strided([4, 3], [1, 4])
1511
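        # input_shapes should record the shape of the tensor argument to
        # aten::as_strided, while concrete_inputs should record the literal
        # size/stride lists passed to it (see the assertions below).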
1512        found = False
1513        for e in p.events():
1514            if e.name == "aten::as_strided":
1515                found = True
1516                self.assertTrue(len(e.input_shapes) > 0)
1517                self.assertTrue(len(e.concrete_inputs) > 0)
1518                self.assertEqual([2, 6], e.input_shapes[0])
1519                self.assertEqual([4, 3], e.concrete_inputs[1])
1520                self.assertEqual([1, 4], e.concrete_inputs[2])
1521
1522        self.assertTrue(found, "Expected to find aten::as_strided but did not")
1523
1524    def test_concrete_inputs_profiling_toggling(self):
1525        try:
1526            for before, after in [(True, False), (False, True)]:
1527                x = torch.rand(2, 6)
1528                torch._C._profiler._set_record_concrete_inputs_enabled_val(before)
1529                with profile(record_shapes=True) as p:
1530                    y = x.as_strided([4, 3], [1, 4])
1531                    torch._C._profiler._set_record_concrete_inputs_enabled_val(after)
1532
1533                found = False
1534                for e in p.events():
1535                    if e.name == "aten::as_strided":
1536                        found = True
1537                        self.assertTrue(len(e.input_shapes))
1538
1539                self.assertTrue(found, "Expected to find aten::as_strided but did not")
1540        finally:
1541            torch._C._profiler._set_record_concrete_inputs_enabled_val(True)
1542
1543    def test_record_function_fast(self):
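        # torch._C._profiler._RecordFunctionFast is a private, low-overhead
        # alternative to torch.profiler.record_function; the same context
        # manager instance is deliberately re-entered across loop iterations.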
1544        x, y = (torch.rand((4, 4)) for _ in range(2))
1545        with profile(record_shapes=True) as p:
1546            for _ in range(4):
1547                # Test first with no optional args
1548                with torch._C._profiler._RecordFunctionFast("add_test_fast_rf1"):
1549                    x.add(y)
1550
1551        self.assertGreaterEqual(
1552            len([e for e in p.events() if e.name == "add_test_fast_rf1"]), 4
1553        )
1554        for e in p.events():
1555            if e.name == "add_test_fast_rf1":
1556                self.assertTrue(e.input_shapes == [])
1557                self.assertTrue(e.kwinputs == {})
1558        with profile(record_shapes=True) as p:
1559            # add optional args
1560            cm = torch._C._profiler._RecordFunctionFast(
1561                "add_test_fast_rf2", [x, y], {"stream": 0, "grid": "lambda x : x + 1"}
1562            )
1563            for _ in range(4):
1564                with cm:
1565                    x.add(y)
1566
1567        self.assertGreaterEqual(
1568            len([e for e in p.events() if e.name == "add_test_fast_rf2"]), 4
1569        )
1570
1571        for e in p.events():
1572            if e.name == "add_test_fast_rf2":
1573                self.assertTrue(e.input_shapes == [[4, 4], [4, 4]])
1574                self.assertTrue(e.kwinputs == {"stream": 0, "grid": "lambda x : x + 1"})
1575
1576        with profile(record_shapes=True) as p:
1577            cm = torch._C._profiler._RecordFunctionFast(
1578                "add_test_fast_rf3", input_values=["hi"], keyword_values={"hi": "hello"}
1579            )
1580            for _ in range(4):
1581                try:
1582                    with cm:
1583                        x.add(y)
1584                        raise ValueError
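                        # unreachable: the assertion after this block checks
                        # that no relu events are recorded when the context
                        # manager exits via an exception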
1585                        x.relu()
1586                except ValueError:
1587                    pass
1588
1589        self.assertGreaterEqual(
1590            len([e for e in p.events() if e.name == "add_test_fast_rf3"]), 4
1591        )
1592        self.assertFalse(any((e.name and "relu" in e.name) for e in p.events()))
1593
1594        for e in p.events():
1595            if e.name == "add_test_fast_rf3":
1596                self.assertTrue(e.input_shapes == [[]])
1597
1598        with profile() as p:
1599            for _ in range(4):
1600                with torch._C._profiler._RecordFunctionFast(
1601                    "add_test_fast_rf4", [x, y]
1602                ):
1603                    x.add(y)
1604                    with torch._C._profiler._RecordFunctionFast("add_test_fast_rf5"):
1605                        x.relu()
1606
1607        self.assertGreaterEqual(
1608            len([e for e in p.events() if e.name == "add_test_fast_rf4"]), 4
1609        )
1610
1611        for e in p.events():
1612            if e.name == "add_test_fast_rf4":
1613                self.assertTrue(e.input_shapes == [])
1614
1615        self.assertGreaterEqual(
1616            len([e for e in p.events() if e.name == "add_test_fast_rf5"]), 4
1617        )
1618
1619        with profile(record_shapes=True) as p:
1620            # test optional args with tuple
1621            cm = torch._C._profiler._RecordFunctionFast(
1622                "add_test_fast_rf6",
1623                (
1624                    x,
1625                    y,
1626                ),
1627            )
1628            for _ in range(4):
1629                with cm:
1630                    x.add(y)
1631
1632        self.assertGreaterEqual(
1633            len([e for e in p.events() if e.name == "add_test_fast_rf6"]), 4
1634        )
1635
1636        for e in p.events():
1637            if e.name == "add_test_fast_rf6":
1638                self.assertTrue(e.input_shapes == [[4, 4], [4, 4]])
1639
1640    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
1641    def test_profiler_op_event_kwargs(self):
1642        x, y = (torch.rand((4, 4)) for _ in range(2))
1643        with profile(record_shapes=True) as p:
1644            cm = torch._C._profiler._RecordFunctionFast(
1645                "add_test_kwinputs",
1646                [x, y],
1647                {"stream": 0, "grid": "lambda x : x + 1", "debug": 'debug"'},
1648            )
1649            for _ in range(4):
1650                with cm:
1651                    x.add(y)
1652        with TemporaryFileName(mode="w+") as fname:
1653            p.export_chrome_trace(fname)
1654            with open(fname) as f:
1655                j = json.load(f)
1656                op_events = [
1657                    e for e in j["traceEvents"] if e.get("cat", "") == "cpu_op"
1658                ]
1659                for e in op_events:
1660                    if e["name"] == "add_test_kwinputs":
1661                        args = e["args"]
1662                        self.assertTrue("stream" in args)
1663                        self.assertTrue("grid" in args)
1664                        self.assertTrue(args["stream"] == 0)
1665                        self.assertTrue(args["grid"] == "lambda x : x + 1")
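                        # the "debug" value contains a quote character; the
                        # exporter appears to sanitize such values to the
                        # string "None" rather than emit broken JSON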
1666                        self.assertTrue(args["debug"] == "None")
1667
1668    def test_is_profiler_enabled(self):
1669        self.assertFalse(torch.autograd.profiler._is_profiler_enabled)
1670
1671        with profile() as p:
1672            self.assertTrue(torch.autograd.profiler._is_profiler_enabled)
1673
1674        self.assertFalse(torch.autograd.profiler._is_profiler_enabled)
1675
1676        with torch.autograd.profiler.profile() as p:
1677            self.assertTrue(torch.autograd.profiler._is_profiler_enabled)
1678
1679        self.assertFalse(torch.autograd.profiler._is_profiler_enabled)
1680
1681    def test_guarded_record_function_fast(self):
1682        x, y = (torch.rand((4, 4)) for _ in range(2))
1683
1684        with profile() as p:
1685            cm = torch._C._profiler._RecordFunctionFast("guarded_rff")
1686            for _ in range(4):
1687                if torch.autograd.profiler._is_profiler_enabled:
1688                    with cm:
1689                        x.add(y)
1690                else:
1691                    x.add(y)
1692
1693        self.assertGreaterEqual(
1694            len([e for e in p.events() if e.name == "guarded_rff"]), 4
1695        )
1696
1697    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
1698    def test_event_list(self):
1699        # AFAIK the event list is part of the legacy profiler and/or is used when kineto is not available.
1700        # This test has basic sanity checks to guard against obvious regressions.
1701        x, y = (torch.rand((4, 4), requires_grad=True, device="cuda") for _ in range(2))
1702        with profile(with_stack=True) as p:
1703            z = (x @ y).relu().sum()
1704            z.backward()
1705
1706        event_list = torch.autograd.profiler_util.EventList(p.events())
1707        # event_list._build_tree()
1708
1709        with TemporaryFileName(mode="w+") as fname:
1710            event_list.export_chrome_trace(fname)
1711            with open(fname) as f:
1712                json.load(f)
1713
1714        event_list.table()
1715
1716    def _check_all_gpu_present(self, gpu_dict, max_gpu_count):
1717        for i in range(0, max_gpu_count):
1718            self.assertEqual(gpu_dict["GPU " + str(i)], 1)
1719
1720    # Do JSON sanity testing. Checks that all events are between profiler start and end,
1721    # and also checks that GPU values are present in the trace if CUDA is used.
1722    def _validate_basic_json(self, traceEvents, cuda_available=False):
1723        MAX_GPU_COUNT = 8
1724        PROFILER_IDX = -4
1725        RECORD_END = -1
1726        RECORD_START = -2
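        # The exporter appends bookkeeping events at the end of traceEvents; the
        # negative indices above locate the "PyTorch Profiler (0)" span and the
        # record window start/end markers from the back of the list.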
1727        traceEventProfiler = traceEvents[PROFILER_IDX]
1728
1729        self.assertTrue(traceEventProfiler["name"] == "PyTorch Profiler (0)")
1730        self.assertTrue(traceEvents[RECORD_END]["name"] == "Record Window End")
1731        self.assertTrue(
1732            traceEvents[RECORD_START]["name"] == "Iteration Start: PyTorch Profiler"
1733        )
1734        # check that the profiler starts/ends within the record interval
1735        self.assertGreaterEqual(
1736            traceEventProfiler["ts"],
1737            traceEvents[RECORD_START]["ts"],
1738            "Profiler starts before record!",
1739        )
1740        self.assertLessEqual(
1741            traceEventProfiler["ts"] + traceEventProfiler["dur"],
1742            traceEvents[RECORD_END]["ts"],
1743            "Profiler ends after record end!",
1744        )
1745
1746        gpu_dict = collections.defaultdict(int)
1747        for i, traceEvent in enumerate(traceEvents):
1748            if (
1749                i == len(traceEvents) + RECORD_END
1750                or i == len(traceEvents) + RECORD_START
1751            ):
1752                continue
1753            # make sure all valid trace events are within the bounds of the profiler
1754            if "ts" in traceEvent:
1755                self.assertGreaterEqual(
1756                    traceEvent["ts"],
1757                    traceEventProfiler["ts"],
1758                    "Trace event is out of bounds",
1759                )
1760            # some python events seem to go a little past the record end, probably because
1761            # of clock inaccuracies, so just compare each event's end time to RECORD_END
1762            if "dur" in traceEvent:
1763                self.assertLessEqual(
1764                    traceEvent["ts"] + traceEvent["dur"],
1765                    traceEvents[RECORD_END]["ts"],
1766                    "Trace event ends too late!",
1767                )
1768            gpu_value = traceEvent.get("args", {}).get("labels", None)
1769            if gpu_value and "GPU" in gpu_value:
1770                gpu_dict[gpu_value] += 1
1771                # Max PID offset is 5M, based on the pytorch/kineto include header:
1772                # https://github.com/pytorch/kineto/blob/8681ff11e1fa54da39023076c5c43eddd87b7a8a/libkineto/include/output_base.h#L35
1773                kExceedMaxPid = 5000000
1774                self.assertTrue(
1775                    traceEvents[i + 1]["args"]["sort_index"]
1776                    == kExceedMaxPid + int(gpu_value.split()[1])
1777                )
1778
1779        # TODO: also check the GPU count depending on whether cpuOnly_ is set
1780
1781    def _test_chrome_trace_basic_helper(self, with_cuda=False):
1782        if with_cuda:
1783            device = "cuda"
1784        else:
1785            device = "cpu"
1786        x, y = (torch.rand(4, 4).to(device) for _ in range(2))
1787
1788        with profile(with_stack=True) as p:
1789            torch.add(x, y)
1790        with TemporaryFileName(mode="w+") as fname:
1791            p.export_chrome_trace(fname)
1792            with open(fname) as f:
1793                report = json.load(f)
1794                self._validate_basic_json(report["traceEvents"], with_cuda)
1795
1796    @unittest.skipIf(not kineto_available(), "Kineto is required")
1797    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
1798    def test_basic_chrome_trace(self):
1799        self._test_chrome_trace_basic_helper()
1800        if torch.cuda.is_available():
1801            self._test_chrome_trace_basic_helper(with_cuda=True)
1802
1803    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
1804    def test_profiler_time_scale(self):
1805        MARGIN_ERROR = 0.5
1806        SEC_TO_US = 1000 * 1000
1807        WAIT_TIME = 10
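        # The record_function span below wraps 10 iterations of ~1s sleep, so
        # after converting cpu_time (us) to seconds it should land within
        # WAIT_TIME +/- MARGIN_ERROR.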
1808        with profile() as p:
1809            with torch.profiler.record_function("test_span"):
1810                for i in range(WAIT_TIME):
1811                    torch.rand(4, 4)
1812                    time.sleep(1)
1813        events = p.events()
1814
1815        # make sure function events are scaled appropriately
1816        self.assertTrue(events[0].name == "test_span")
1817        test_span = events[0]
1818        self.assertGreaterEqual(
1819            test_span.cpu_time / SEC_TO_US,
1820            WAIT_TIME - MARGIN_ERROR,
1821            "event out of range",
1822        )
1823        self.assertLessEqual(
1824            test_span.cpu_time / SEC_TO_US,
1825            WAIT_TIME + MARGIN_ERROR,
1826            "event out of range",
1827        )
1828
1829        # make sure tracing is scaled appropriately
1830        with TemporaryFileName(mode="w+") as fname:
1831            p.export_chrome_trace(fname)
1832            with open(fname) as f:
1833                report = json.load(f)
1834            events = report["traceEvents"]
1835            for event in events:
1836                if event["name"] == "test_span":
1837                    self.assertGreaterEqual(
1838                        event["dur"] / SEC_TO_US,
1839                        WAIT_TIME - MARGIN_ERROR,
1840                        "profiling out of range",
1841                    )
1842                    self.assertLessEqual(
1843                        event["dur"] / SEC_TO_US,
1844                        WAIT_TIME + MARGIN_ERROR,
1845                        "profiling out of range",
1846                    )
1847
1848    def _schedule_helper(self, warmup, active, repeat, acc_events=True):
1849        with profile(
1850            schedule=torch.profiler.schedule(
1851                skip_first=0,
1852                wait=0,
1853                warmup=warmup,
1854                active=active,
1855                repeat=repeat,
1856            ),
1857            acc_events=acc_events,
1858        ) as prof:
1859            for i in range(100):
1860                torch.add(1, 2)
1861                prof.step()
1862        # print(prof.key_averages())
1863        for ev in prof.key_averages():
1864            if ev.key == "aten::add":
1865                return ev.count
1866        return 0
1867
1868    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
1869    def test_schedule_function_count(self):
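        # Expected counts follow from the schedule arithmetic: a cycle is
        # (wait + warmup + active) steps and only "active" steps record events;
        # repeat=0 keeps cycling until the 100 iterations run out. E.g. with
        # warmup=1, active=5 the cycle length is 6, so 16 full cycles (80 events)
        # plus 3 active steps of the trailing partial cycle give 83. With
        # acc_events=False events from earlier cycles appear to be dropped once a
        # new cycle starts, so only the last cycle's events (if any) survive.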
1870        self.assertEqual(self._schedule_helper(warmup=0, active=1, repeat=1), 1)
1871        self.assertEqual(self._schedule_helper(warmup=0, active=5, repeat=0), 100)
1872        self.assertEqual(self._schedule_helper(warmup=0, active=5, repeat=10), 50)
1873        self.assertEqual(self._schedule_helper(warmup=1, active=5, repeat=0), 83)
1874        self.assertEqual(self._schedule_helper(warmup=10, active=10, repeat=4), 40)
1875        self.assertEqual(self._schedule_helper(warmup=50, active=1, repeat=0), 1)
1876        self.assertEqual(
1877            self._schedule_helper(warmup=0, active=5, repeat=0, acc_events=False), 0
1878        )
1879        self.assertEqual(
1880            self._schedule_helper(warmup=10, active=10, repeat=4, acc_events=False), 10
1881        )
1882
1883    def _step_helper_func(self, prof):
1884        time.sleep(0.1)
1885        torch.randn(1, 3, 224, 224)
1886        prof.step()
1887
1888    def _partial_overlap(self, prof_step, step_helper_func):
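        # Returns True only for a *partial* overlap, i.e. one interval starts
        # before the other and ends inside it; full containment or no overlap
        # returns False.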
1889        p_start = prof_step["ts"]
1890        p_end = prof_step["ts"] + prof_step["dur"]
1891        h_start = step_helper_func["ts"]
1892        h_end = step_helper_func["ts"] + step_helper_func["dur"]
1893
1894        if p_start < h_start and p_end < h_end and p_end > h_start:
1895            return True
1896        if p_start > h_start and p_start < h_end and p_end > h_end:
1897            return True
1898        return False
1899
1900    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
1901    def test_cpu_annotation_overlap(self):
1902        with torch.profiler.profile(
1903            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
1904            record_shapes=True,
1905            with_stack=True,
1906            schedule=torch.profiler.schedule(wait=0, warmup=0, active=5, repeat=1),
1907        ) as prof:
1908            for i in range(5):
1909                self._step_helper_func(prof)
1910        with TemporaryFileName(mode="w+") as fname:
1911            prof.export_chrome_trace(fname)
1912            prof_steps = []
1913            step_helper_funcs = []
1914            with open(fname) as f:
1915                report = json.load(f)
1916                for event in report["traceEvents"]:
1917                    if "ProfilerStep" in event["name"]:
1918                        prof_steps.append(event)
1919                    if "step_helper_func" in event["name"]:
1920                        step_helper_funcs.append(event)
1921            self.assertEqual(len(prof_steps), 5)
1922            self.assertEqual(len(step_helper_funcs), 5)
1923            for i in range(0, len(step_helper_funcs)):
1924                for j in range(0, len(step_helper_funcs)):
1925                    self.assertTrue(
1926                        not self._partial_overlap(prof_steps[i], step_helper_funcs[j])
1927                    )
1928
1929    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
1930    def test_user_annotation(self):
1931        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
1932        with profile(activities=supported_activities()) as p:
1933            with torch.profiler.record_function("test_user_annotation"):
1934                self.payload(use_cuda=use_cuda)
1935
1936        for evt in p.key_averages():
1937            if evt.key == "test_user_annotation":
1938                self.assertTrue(evt.is_user_annotation)
1939            else:
1940                self.assertFalse(evt.is_user_annotation)
1941
1942    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
1943    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
1944    def test_dynamic_toggle(self):
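        # toggle_collection_dynamic(False, activities) turns off collection for
        # the given activities while the profiler is running: disabling CUDA
        # should drop cuda/kernel events, and disabling CPU as well should
        # leave no events at all (checked below).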
1945        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as p:
1946            with torch.profiler.record_function("test_user_annotation"):
1947                x, y = (torch.rand(4, 4).to("cuda") for _ in range(2))
1948                torch.add(x, y)
1949
1950        self.assertTrue(any("aten" in e.name for e in p.events()))
1951
1952        self.assertTrue(any("cuda" in e.name for e in p.events()))
1953
1954        self.assertTrue(any("kernel" in e.name for e in p.events()))
1955
1956        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as p1:
1957            p1.toggle_collection_dynamic(False, [ProfilerActivity.CUDA])
1958            with torch.profiler.record_function("test_user_annotation"):
1959                x, y = (torch.rand(4, 4).to("cuda") for _ in range(2))
1960                torch.add(x, y)
1961
1962        self.assertTrue(any("aten" in e.name for e in p1.events()))
1963
1964        self.assertTrue(all("cuda" not in e.name for e in p1.events()))
1965
1966        self.assertTrue(all("kernel" not in e.name for e in p1.events()))
1967
1968        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as p2:
1969            p2.toggle_collection_dynamic(
1970                False, [ProfilerActivity.CUDA, ProfilerActivity.CPU]
1971            )
1972            with torch.profiler.record_function("test_user_annotation"):
1973                x, y = (torch.rand(4, 4).to("cuda") for _ in range(2))
1974                torch.add(x, y)
1975        self.assertTrue(len(p2.events()) == 0)
1976
1977    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
1978    def test_lazy_build_tree(self):
1979        with profile() as p:
1980            self.payload()
1981
1982        stats = p._stats()
1983        # Test that the tree is not built
1984        self.assertEqual(stats.function_events_build_tree_call_duration_us, 0)
1985        self.assertEqual(stats.number_of_events, 0)
1986
1987        # Test that the tree is built on demand
1988        p.events()
1989        self.assertGreater(stats.function_events_build_tree_call_duration_us, 0)
1990        self.assertGreater(stats.number_of_events, 0)
1991
1992
1993class SimpleNet(nn.Module):
1994    def __init__(self) -> None:
1995        super().__init__()
1996        self.fc1 = nn.Linear(10, 5)
1997        self.fc2 = nn.Linear(5, 2)
1998
1999    def forward(self, x):
2000        return self.fc2(self.fc1(x))
2001
2002
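# Minimal stand-ins for kineto/profiler events: they expose only the attributes
# and methods that torch.profiler._utils.BasicEvaluation reads, so the queue
# depth / idle time tests below can run on deterministic, hand-written data.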
2003@dataclass(frozen=True)
2004class MockKinetoEvent:
2005    _name: str
2006    _start_us: int
2007    _duration_us: int
2008    _linked_correlation_id: int
2009    _device_type: int
2010
2011    @property
2012    def name(self) -> str:
2013        return self._name
2014
2015    def start_ns(self) -> int:
2016        return self._start_us * 1000
2017
2018    def duration_ns(self) -> int:
2019        return self._duration_us * 1000
2020
2021    def linked_correlation_id(self) -> int:
2022        return self._linked_correlation_id
2023
2024    def device_type(self) -> DeviceType:
2025        return DeviceType.CUDA if self._device_type == 1 else DeviceType.CPU
2026
2027
2028@dataclass(frozen=True)
2029class MockProfilerEvent:
2030    _name: str
2031    id: int
2032    start_time_ns: int
2033    duration_time_ns: int
2034    correlation_id: int = 0
2035    children: List["MockProfilerEvent"] = field(default_factory=list)
2036    parent: Optional["MockProfilerEvent"] = None
2037
2038    @property
2039    def end_time_ns(self):
2040        return self.start_time_ns + self.duration_time_ns
2041
2042    @property
2043    def name(self) -> str:
2044        return self._name
2045
2046    def __post__init__(self, parent, children):
2047        object.__setattr__(self, "parent", parent)
2048        object.__setattr__(self, "children", children)
2049
2050
2051class MockNode:
2052    def __init__(self, name, children) -> None:
2053        self.name = name
2054        self.children = [MockNode(name, i) for name, i in children.items()]
2055
2056
2057class TestExperimentalUtils(TestCase):
2058    def make_tree(self) -> List[MockNode]:
2059        tree = {
2060            "root_0": {
2061                "1": {"2": {}},
2062                "3": {
2063                    "4": {},
2064                    "5": {},
2065                },
2066            },
2067            "root_1": {
2068                "6": {},
2069                "7": {},
2070                "8": {
2071                    "9": {"10": {}},
2072                },
2073            },
2074        }
2075        return [MockNode(name, i) for name, i in tree.items()]
2076
2077    def test_dfs(self) -> None:
2078        self.assertEqual(
2079            " ".join(i.name for i in _utils.traverse_dfs(self.make_tree())),
2080            "root_0 1 2 3 4 5 root_1 6 7 8 9 10",
2081        )
2082
2083    def test_bfs(self) -> None:
2084        self.assertEqual(
2085            " ".join(i.name for i in _utils.traverse_bfs(self.make_tree())),
2086            "root_0 root_1 1 3 6 7 8 2 4 5 9 10",
2087        )
2088
2089    @staticmethod
2090    def generate_mock_profile():
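        # Hand-written timeline: five back-to-back kernel launches (queue depth
        # climbs to 5), the matching GPU kernels then drain the queue, followed
        # by one final launch/kernel pair. Kineto event times are in
        # microseconds, profiler (CPU op) event times in nanoseconds, matching
        # the mock classes above.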
2091        cuda_events = [
2092            MockKinetoEvent("cudaLaunchKernel", 400, 100, 1, 0),
2093            MockKinetoEvent("cudaLaunchKernel", 500, 100, 2, 0),
2094            MockKinetoEvent("cudaLaunchKernel", 600, 100, 3, 0),
2095            MockKinetoEvent("cudaLaunchKernel", 700, 100, 4, 0),
2096            MockKinetoEvent("cudaLaunchKernel", 800, 100, 5, 0),
2097            MockKinetoEvent("cudaLaunchKernel", 1500, 100, 6, 0),
2098            MockKinetoEvent("GPU", 900, 100, 1, 1),
2099            MockKinetoEvent("GPU", 1000, 100, 2, 1),
2100            MockKinetoEvent("GPU", 1100, 100, 3, 1),
2101            MockKinetoEvent("GPU", 1200, 100, 4, 1),
2102            MockKinetoEvent("GPU", 1300, 100, 5, 1),
2103            MockKinetoEvent("GPU", 1700, 100, 6, 1),
2104        ]
2105        cpu_events = [
2106            MockProfilerEvent("CPU (Before cudaLaunchKernel)", 1, 0, 100000),
2107            MockProfilerEvent("CPU (Before cudaLaunchKernel)", 2, 100000, 100000),
2108            MockProfilerEvent("CPU (Before cudaLaunchKernel)", 3, 200000, 100000),
2109            MockProfilerEvent("CPU (Before cudaLaunchKernel)", 4, 300000, 100000),
2110            MockProfilerEvent("CPU (After cudaLaunchKernel)", 5, 400000, 100000),
2111            MockProfilerEvent("CPU (After cudaLaunchKernel)", 6, 500000, 100000),
2112            MockProfilerEvent("CPU (After cudaLaunchKernel)", 7, 600000, 100000),
2113            MockProfilerEvent("CPU (After cudaLaunchKernel)", 8, 700000, 100000),
2114            MockProfilerEvent("CPU (After GPU)", 9, 800000, 100000),
2115            MockProfilerEvent("CPU (After GPU)", 10, 900000, 100000),
2116            MockProfilerEvent("CPU (After GPU)", 11, 1100000, 100000),
2117            MockProfilerEvent("CPU (After GPU)", 12, 1200000, 500000),
2118        ]
2119
2120        profiler = unittest.mock.Mock()
2121        profiler.kineto_results = unittest.mock.Mock()
2122        profiler.kineto_results.events = unittest.mock.Mock(return_value=cuda_events)
2123        profiler.kineto_results.experimental_event_tree = unittest.mock.Mock(
2124            return_value=cpu_events
2125        )
2126        return profiler
2127
2128    @staticmethod
2129    def load_mock_profile():
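        # When expecttest.ACCEPT is set and CUDA is available, regenerate
        # profiler_utils_mock_events.json from a real profile; otherwise replay
        # the stored events through the same mock objects used above.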
2130        accept = expecttest.ACCEPT
2131        json_file_path = os.path.join(
2132            os.path.dirname(os.path.realpath(__file__)),
2133            "profiler_utils_mock_events.json",
2134        )
2135        if accept and torch.cuda.is_available():
2136
2137            def garbage_code(x):
2138                for i in range(5):
2139                    x[0, i] = i
2140
2141            x = torch.ones((4096, 4096), device="cuda")
2142            x = x @ x
2143            with profile(
2144                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
2145                record_shapes=True,
2146                with_stack=True,
2147            ) as prof:
2148                for _ in range(5):
2149                    x = x @ x
2150                garbage_code(x)
2151                for _ in range(5):
2152                    x = x @ x
2153
2154            kineto_events = [
2155                {
2156                    "_name": e.name,
2157                    "_start_ns": e.start_ns(),
2158                    "_duration_ns": e.duration_ns(),
2159                    "_linked_correlation_id": e.linked_correlation_id(),
2160                    "_device_type": 1 if e.device_type() == DeviceType.CUDA else 0,
2161                }
2162                for e in prof.profiler.kineto_results.events()
2163            ]
2164
2165            def EventTreeDFS(event_tree):
2166                from collections import deque
2167
2168                stack = deque(event_tree)
2169                while stack:
2170                    curr_event = stack.pop()
2171                    yield curr_event
2172                    for child_event in curr_event.children:
2173                        stack.append(child_event)
2174
2175            profiler_events = [
2176                {
2177                    "_name": e.name,
2178                    "id": e.id,
2179                    "start_time_ns": e.start_time_ns,
2180                    "duration_time_ns": e.duration_time_ns,
2181                    "correlation_id": e.correlation_id,
2182                    "children": [child.id for child in e.children],
2183                    "parent": e.parent.id if e.parent else None,
2184                }
2185                for e in EventTreeDFS(
2186                    prof.profiler.kineto_results.experimental_event_tree()
2187                )
2188            ]
2189
2190            with open(json_file_path, "w") as f:
2191                json.dump([kineto_events, profiler_events], f)
2192
2193        assert os.path.exists(json_file_path)
2194        with open(json_file_path) as f:
2195            kineto_events, profiler_events = json.load(f)
2196
2197        cuda_events = [MockKinetoEvent(*event.values()) for event in kineto_events]
2198        cpu_events = []
2199        id_map = {}
2200        for e in profiler_events:
2201            event = MockProfilerEvent(**e)
2202            id_map[event.id] = event
2203            cpu_events.append(event)
2204        for event in cpu_events:
2205            parent = None if event.parent is None else id_map[event.parent]
2206            children = [id_map[child] for child in event.children]
2207            event.__post__init__(parent, children)
2208        cpu_events = [event for event in cpu_events if event.parent is None]
2209        profiler = unittest.mock.Mock()
2210        profiler.kineto_results = unittest.mock.Mock()
2211        profiler.kineto_results.events = unittest.mock.Mock(return_value=cuda_events)
2212        profiler.kineto_results.experimental_event_tree = unittest.mock.Mock(
2213            return_value=cpu_events
2214        )
2215        return profiler
2216
2217    def test_utils_compute_self_time(self):
2218        with profile() as prof:
2219            t1, t2 = torch.ones(1, requires_grad=True), torch.ones(
2220                1, requires_grad=True
2221            )
2222            z = torch.add(t1, t2)
2223            y = torch.ones(1)
2224            loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
2225            loss.backward()
2226        basic_eval = _utils.BasicEvaluation(prof.profiler)
2227        metrics = basic_eval.metrics
2228        self.assertTrue(len(metrics) > 0)
2229        for event_key, event_metrics in metrics.items():
2230            self.assertEqual(
2231                event_metrics.self_time_ns,
2232                event_key.event.duration_time_ns
2233                - sum(child.duration_time_ns for child in event_key.event.children),
2234            )
2235
2236    def test_utils_intervals_overlap(self):
2237        event = _utils.EventKey(MockProfilerEvent("Event 1", 1, 5, 5))
2238        intervals = [
2239            _utils.Interval(0, 9),
2240            _utils.Interval(1, 2),
2241            _utils.Interval(2, 3),
2242            _utils.Interval(3, 4),
2243            _utils.Interval(4, 5),
2244            _utils.Interval(8, 12),
2245        ]
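        # The event spans [5, 10). Interval (0, 9) contributes [5, 9) and
        # (8, 12) contributes [8, 10); the overlapping piece is merged, so the
        # total overlap is 5.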
2246        print(event.intervals_overlap(intervals))
2247        self.assertEqual(event.intervals_overlap(intervals), 5)
2248
2249    def test_utils_compute_queue_depth(self):
2250        def format_queue_depth(queue_depth_list, events):
2251            res = ""
2252            for data, event in zip(queue_depth_list, events):
2253                res += f"{data.queue_depth} [{event.name}]\n"
2254            return res
2255
2256        # We have to use Mock because time series data is too flaky to test
2257        profiler = self.generate_mock_profile()
2258        basic_evaluation = _utils.BasicEvaluation(profiler)
2259        self.assertExpectedInline(
2260            format_queue_depth(
2261                basic_evaluation.queue_depth_list, basic_evaluation.cuda_events
2262            ),
2263            """\
22641 [cudaLaunchKernel]
22652 [cudaLaunchKernel]
22663 [cudaLaunchKernel]
22674 [cudaLaunchKernel]
22685 [cudaLaunchKernel]
22694 [GPU]
22703 [GPU]
22712 [GPU]
22721 [GPU]
22730 [GPU]
22741 [cudaLaunchKernel]
22750 [GPU]
2276""",
2277        )
2278        self.assertExpectedInline(
2279            format_queue_depth(
2280                [basic_evaluation.metrics[k] for k in basic_evaluation.event_keys],
2281                basic_evaluation.events,
2282            ),
2283            """\
22840 [CPU (Before cudaLaunchKernel)]
22850 [CPU (Before cudaLaunchKernel)]
22860 [CPU (Before cudaLaunchKernel)]
22870 [CPU (Before cudaLaunchKernel)]
22881 [CPU (After cudaLaunchKernel)]
22892 [CPU (After cudaLaunchKernel)]
22903 [CPU (After cudaLaunchKernel)]
22914 [CPU (After cudaLaunchKernel)]
22925 [CPU (After GPU)]
22934 [CPU (After GPU)]
22942 [CPU (After GPU)]
22951 [CPU (After GPU)]
2296""",
2297        )
2298
2299    def test_utils_compute_queue_depth_when_no_cuda_events(self):
2300        # For traces with only cpu events, we expect empty queue depth list
2301        x = torch.ones((1024, 1024))
2302        with profile() as prof:
2303            for _ in range(5):
2304                x = x @ x
2305        basic_evaluation = _utils.BasicEvaluation(prof.profiler)
2306        self.assertFalse(basic_evaluation.compute_queue_depth())
2307
2308    def test_utils_compute_idle_time(self):
2309        profiler = self.generate_mock_profile()
2310        basic_evaluation = _utils.BasicEvaluation(profiler)
2311        expected_output = "\n".join(
2312            [
2313                f"{basic_evaluation.metrics[event_key].idle_time_ns} [{event_key.event.name}]"
2314                for event_key in basic_evaluation.event_keys
2315            ]
2316        )
2317        self.assertExpectedInline(
2318            expected_output,
2319            """\
2320100000 [CPU (Before cudaLaunchKernel)]
2321100000 [CPU (Before cudaLaunchKernel)]
2322100000 [CPU (Before cudaLaunchKernel)]
2323100000 [CPU (Before cudaLaunchKernel)]
23240 [CPU (After cudaLaunchKernel)]
23250 [CPU (After cudaLaunchKernel)]
23260 [CPU (After cudaLaunchKernel)]
23270 [CPU (After cudaLaunchKernel)]
23280 [CPU (After GPU)]
23290 [CPU (After GPU)]
23300 [CPU (After GPU)]
2331100000 [CPU (After GPU)]""",
2332        )
2333
2334    @unittest.skipIf(IS_JETSON, "JSON not behaving as expected on Jetson")
2335    def test_utils_get_optimizable_events(self):
2336        basic_evaluation = _utils.BasicEvaluation(self.load_mock_profile())
2337        optimizable_events = basic_evaluation.get_optimizable_events(
2338            2, print_enable=False
2339        )
2340        expected_output = "\n".join(
2341            [f"{event_key.event.name}" for event_key in optimizable_events]
2342        )
2343        self.assertExpectedInline(
2344            expected_output,
2345            """\
2346<built-in function _cuda_synchronize>
2347aten::copy_""",
2348        )
2349
2350    def test_profiler_name_pattern(self):
2351        x = torch.ones((4096, 4096))
2352        with profile() as prof:
2353            for _ in range(5):
2354                x = x @ x
2355                x = x + x
2356        matched_events = NamePattern(prof, "aten::mm").matched_events()
2357        output = "\n".join([f"{event.name}" for event in matched_events])
2358        self.assertExpectedInline(
2359            output,
2360            """\
2361aten::mm
2362aten::mm
2363aten::mm
2364aten::mm
2365aten::mm""",
2366        )
2367
2368    # TODO: Add logic for CUDA version of test
2369    @unittest.skipIf(torch.cuda.is_available(), "Test not working for CUDA")
2370    def test_profiler_pattern_match_helper(self):
2371        x = torch.ones((100, 100))
2372        with profile() as prof:
2373            for _ in range(5):
2374                x = x @ x
2375                x = x + x
2376        event_tree = prof.profiler.kineto_results.experimental_event_tree()
2377        pattern = Pattern(prof)
2378        self.assertEqual([], pattern.siblings_of(event_tree[0])[0])
2379        self.assertEqual(event_tree[1:], pattern.siblings_of(event_tree[0])[1])
2380        child_nodes = event_tree[0].children
2381        self.assertEqual([], pattern.siblings_of(child_nodes[0])[0])
2382        self.assertEqual(child_nodes[1:], pattern.siblings_of(child_nodes[0])[1])
2383        self.assertEqual(
2384            event_tree[0], pattern.root_of(event_tree[0].children[0].children[0])
2385        )
2386        self.assertEqual(None, pattern.next_of(event_tree[-1]))
2387        self.assertEqual(event_tree[1], pattern.next_of(event_tree[0]))
2388        self.assertEqual(event_tree[0], pattern.prev_of(event_tree[1]))
2389
2390    @unittest.skipIf(
2391        TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
2392    )
2393    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
2394    def test_profiler_extra_cuda_copy_pattern(self):
2395        cases = (
2396            (0, lambda: torch.ones((100, 100), device="cuda")),
2397            (1, lambda: torch.ones((100, 100)).to("cuda")),
2398            (1, lambda: torch.zeros((100, 100)).to("cuda")),
2399            (1, lambda: torch.empty((100, 100)).fill_(5).to("cuda")),
2400            (1, lambda: torch.ones((100, 100)).cuda()),
2401            (1, lambda: torch.zeros((100, 100)).cuda()),
2402            (1, lambda: torch.empty((100, 100)).fill_(5).cuda()),
2403            (1, lambda: torch.rand((100, 100)).cuda()),
2404            (1, lambda: torch.randn((100, 100)).cuda()),
2405            (1, lambda: torch.full((100, 100), 10).cuda()),
2406            (0, lambda: torch.rand((100, 100)).to(dtype=torch.float16)),
2407            (0, lambda: torch.rand((100, 100)).half()),
2408            (0, lambda: torch.rand((100, 100), device="cuda").half()),
2409        )
2410        num_matched = []
2411        for _, fn in cases:
2412            with profile(with_stack=True, record_shapes=True) as prof:
2413                fn()
2414            pattern = ExtraCUDACopyPattern(prof)
2415            num_matched.append(len(pattern.matched_events()))
2416        self.assertEqual(num_matched, [i for i, _ in cases])
2417
2418    @unittest.skipIf(
2419        TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
2420    )
2421    def test_profiler_for_loop_indexing_pattern(self):
2422        x = torch.ones((100, 100))
2423
2424        def case1():
2425            for i in range(100):
2426                x[i] = i
2427
2428        def case2():
2429            y = 0
2430            for i in range(100):
2431                y += x[i]
2432
2433        def case3():
2434            y = 1
2435            for i in range(100):
2436                y *= x[i]
2437
2438        def case4():
2439            y = x
2440            for _ in range(100):
2441                y = y @ x
2442
2443        def case5():
2444            for i in range(100):
2445                x[i, :] = torch.arange(100) + i
2446
2447        cases = ((1, case1), (1, case2), (1, case3), (0, case4), (1, case5))
2448        num_matched = []
2449        for _, fn in cases:
2450            with profile(with_stack=True) as prof:
2451                fn()
2452            pattern = ForLoopIndexingPattern(prof)
2453            num_matched.append(len(pattern.matched_events()))
2454        self.assertEqual(num_matched, [i for i, _ in cases])
2455
2456    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
2457    def test_profiler_fp32_matmul_pattern(self):
2458        x = torch.ones((100, 100), device="cuda")
2459        with profile(with_stack=True) as prof:
2460            x = x @ x
2461        pattern = FP32MatMulPattern(prof)
2462        has_tf32 = 0 if pattern.skip else 1
2463        num_matched = len(pattern.matched_events())
2464        self.assertEqual(num_matched, has_tf32)
2465
2466    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
2467    def test_profiler_extra_cuda_copy_pattern_benchmark(self):
2468        with profile(with_stack=True, record_shapes=True) as prof:
2469            x = torch.ones((100, 100)).to("cuda")
2470            x = torch.ones((50, 50)).to("cuda")
2471        pattern = ExtraCUDACopyPattern(prof)
2472        shapes_factor_map = pattern.benchmark(pattern.matched_events())
2473        self.assertEqual(len(shapes_factor_map), 2)
2474
2475    def test_profiler_optimizer_single_tensor_pattern(self):
2476        x = torch.ones((100, 100))
2477        cases = (
2478            (1, lambda: torch.optim.Adam(model.parameters())),
2479            (1, lambda: torch.optim.SGD(model.parameters(), lr=0.01)),
2480            (1, lambda: torch.optim.AdamW(model.parameters())),
2481            (0, lambda: torch.optim.Adam(model.parameters(), foreach=True)),
2482            (0, lambda: torch.optim.SGD(model.parameters(), lr=0.01, foreach=True)),
2483            (0, lambda: torch.optim.AdamW(model.parameters(), foreach=True)),
2484        )
2485        num_matched = []
2486        for _, fn in cases:
2487            with profile(with_stack=True) as prof:
2488                model = nn.Sequential(
2489                    nn.Linear(100, 100),
2490                    nn.ReLU(),
2491                    nn.Linear(100, 10),
2492                )
2493                optimizer = fn()
2494                optimizer.zero_grad()
2495                y_hat = model(x)
2496                loss = torch.nn.functional.cross_entropy(
2497                    y_hat, torch.randint(0, 10, (100,))
2498                )
2499                loss.backward()
2500                optimizer.step()
2501            pattern = OptimizerSingleTensorPattern(prof)
2502            num_matched.append(len(pattern.matched_events()))
2503        self.assertEqual(num_matched, [i for i, _ in cases])
2504
2505    def test_profiler_synchronized_dataloader_pattern(self):
2506        dataset = torch.rand((100, 100))
2507        sync_dataloader = torch.utils.data.DataLoader(dataset, batch_size=10)
2508        async_dataloader = torch.utils.data.DataLoader(
2509            dataset, batch_size=10, num_workers=4
2510        )
2511        with profile(with_stack=True) as prof:
2512            next(iter(sync_dataloader))
2513            next(iter(async_dataloader))
2514        pattern = SynchronizedDataLoaderPattern(prof)
2515        num_matched = len(pattern.matched_events())
2516        self.assertEqual(num_matched, 1)
2517
2518    @skipIfTorchDynamo(
2519        "pattern checks for aten::zero_ op which might not be there with torch.compile'd graph"
2520    )
2521    def test_profiler_grad_not_set_to_none_pattern(self):
2522        x = torch.ones((100, 100))
2523        model = nn.Sequential(
2524            nn.Linear(100, 100),
2525            nn.ReLU(),
2526            nn.Linear(100, 10),
2527        )
2528        optimizer = torch.optim.Adam(model.parameters())
2529        cases = (
2530            (0, lambda: optimizer.zero_grad()),
2531            (0, lambda: model.zero_grad()),
2532            (1, lambda: optimizer.zero_grad(set_to_none=False)),
2533            (1, lambda: model.zero_grad(set_to_none=False)),
2534        )
2535        num_matched = []
2536        for _, fn in cases:
2537            with profile(with_stack=True) as prof:
2538                y_hat = model(x)
2539                loss = torch.nn.functional.cross_entropy(
2540                    y_hat, torch.randint(0, 10, (100,))
2541                )
2542                loss.backward()
2543                optimizer.step()
2544                fn()
2545            pattern = GradNotSetToNonePattern(prof)
2546            num_matched.append(len(pattern.matched_events()))
2547        self.assertEqual(num_matched, [i for i, _ in cases])
2548
2549    def test_profiler_conv2d_bias_followed_by_batchnorm2d_pattern(self):
2550        x = torch.randn((1, 3, 32, 32))
2551        cases = (
2552            (1, nn.Sequential(nn.Conv2d(3, 3, 3, 1, 1), nn.BatchNorm2d(3))),
2553            (0, nn.Sequential(nn.Conv2d(3, 3, 3, 1, 1, bias=False), nn.BatchNorm2d(3))),
2554            (0, nn.Sequential(nn.Conv2d(3, 3, 3, 1, 1))),
2555        )
2556        num_matched = []
2557        for _, model in cases:
2558            with profile(with_stack=True, record_shapes=True) as prof:
2559                model(x)
2560            pattern = Conv2dBiasFollowedByBatchNorm2dPattern(prof)
2561            num_matched.append(len(pattern.matched_events()))
2562        self.assertEqual(num_matched, [i for i, _ in cases])
2563
2564    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
2565    def test_profiler_matmul_dim_fp16_pattern(self):
2566        cases = (
2567            (1, torch.randn((201, 201), device="cuda", dtype=torch.float16)),
2568            (1, torch.randn((3, 97, 97), device="cuda", dtype=torch.float16)),
2569            (0, torch.randn((200, 200), device="cuda", dtype=torch.float16)),
2570            (0, torch.randn((3, 200, 200), device="cuda", dtype=torch.float16)),
2571        )
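        # MatMulDimInFP16Pattern presumably flags fp16 matmuls whose dimensions
        # are not multiples of 8 (and so cannot fully use Tensor Cores): 201 and
        # 97 are not multiples of 8, while 200 is.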
2572        num_matched = []
2573        for _, x in cases:
2574            with profile(with_stack=True, record_shapes=True) as prof:
2575                x @ x
2576            pattern = MatMulDimInFP16Pattern(prof)
2577            num_matched.append(len(pattern.matched_events()))
2578        self.assertEqual(num_matched, [i for i, _ in cases])
2579
2580    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
2581    def test_profiler_pattern_matcher_json_report(self):
2582        x = torch.ones((100, 100))
2583        model = nn.Sequential(
2584            nn.Linear(100, 100),
2585            nn.ReLU(),
2586            nn.Linear(100, 10),
2587        )
2588        optimizer = torch.optim.Adam(model.parameters())
2589        with profile(with_stack=True, record_shapes=True) as prof:
2590            y_hat = model(x)
2591            loss = torch.nn.functional.cross_entropy(
2592                y_hat, torch.randint(0, 10, (100,))
2593            )
2594            loss.backward()
2595            optimizer.step()
2596            optimizer.zero_grad()
2597
2598        with tempfile.TemporaryDirectory() as tmpdir:
2599            report_all_anti_patterns(prof, json_report_dir=tmpdir, print_enable=False)
2600
2601            with open(os.path.join(tmpdir, "torchtidy_report.json")) as f:
2602                report = json.load(f)
2603
2604            # It is platform dependent whether the path will include "profiler/"
2605            keys = [k for k in report.keys() if k.endswith("test_profiler.py")]
2606            self.assertEqual(len(keys), 1, f"{keys}")
2607            entry = report[keys[0]]
2608
2609            self.assertTrue(len(entry) > 0)
2610            expected_fields = sorted(["line_number", "name", "url", "message"])
2611            for event in entry:
2612                actual_fields = sorted(event.keys())
2613                self.assertEqual(expected_fields, actual_fields)
2614
2615    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "x86 linux only cpp unwinding")
2616    def test_fuzz_symbolize(self):
2617        # generate some random addresses in the text section and make sure the
2618        # symbolizers do not throw exceptions/crash
2619        def get_text_sections():
2620            text_sections = []
2621            seen = set()
2622            for filename in os.listdir("/proc/self/map_files"):
2623                library = os.readlink("/proc/self/map_files/" + filename)
2624                if ".so" not in library or library in seen:
2625                    continue
2626                seen.add(library)
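                # library is an absolute path, so os.path.join() below simply
                # returns it unchanged and we mmap the shared object itself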
2627                with open(os.path.join("/proc/self/map_files", library), "rb") as f:
2628                    mm = mmap.mmap(f.fileno(), 0, prot=mmap.PROT_READ)
2629
2630                    def unpack(fmt, offset):
2631                        return struct.unpack(
2632                            fmt, mm[offset : offset + struct.calcsize(fmt)]
2633                        )
2634
2635                    if mm[:4] != b"\x7fELF":
2636                        continue
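                    # ELF64 header layout: e_shoff at byte offset 40,
                    # e_shentsize at 58, e_shnum at 60, e_shstrndx at 62; within
                    # each section header, sh_name is at offset 0, sh_offset at
                    # 24 and sh_size at 32.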
2637                    (section_headers_start,) = unpack("Q", 40)
2638                    (section_header_size,) = unpack("H", 58)
2639                    (num_section_headers,) = unpack("H", 60)
2640                    (shstrndx,) = unpack("H", 62)
2641                    (shstrtab_offset,) = unpack(
2642                        "Q", section_headers_start + shstrndx * section_header_size + 24
2643                    )
2644                    for i in range(num_section_headers):
2645                        (section_name_offset,) = unpack(
2646                            "I", section_headers_start + i * section_header_size
2647                        )
2648                        name_start = shstrtab_offset + section_name_offset
2649                        section_name = mm[name_start : name_start + 6]
2650                        if section_name != b".text\0":
2651                            continue
2652                        (section_offset,) = unpack(
2653                            "Q", section_headers_start + i * section_header_size + 24
2654                        )
2655                        (section_size,) = unpack(
2656                            "Q", section_headers_start + i * section_header_size + 32
2657                        )
2658                        start = int(filename.split("-")[0], 16) + section_offset
2659                        text_sections.append((start, section_size))
2660                        break
2661                    mm.close()
2662            return text_sections
2663
2664        r = random.Random()
2665        r.seed(1)
2666        text_sections = get_text_sections()
2667        addrs = []
2668        for i in range(200):
2669            s = r.randrange(0, len(text_sections))
2670            start, size = text_sections[s]
2671            addr = r.randrange(start, start + size)
2672            addrs.append(addr)
2673        fast = torch._C._profiler.symbolize_addresses(addrs, "fast")
2674        dladdr = torch._C._profiler.symbolize_addresses(addrs, "dladdr")
2675        addr2line = torch._C._profiler.symbolize_addresses(addrs, "addr2line")
2676        self.assertEqual(len(fast), len(addrs))
2677        self.assertEqual(len(addr2line), len(fast))
2678
2679
2680if __name__ == "__main__":
2681    run_tests()
2682