xref: /aosp_15_r20/external/pytorch/test/profiler/test_record_function.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: profiler"]
2
3# if tqdm is not shutdown properly, it will leave the monitor thread alive.
4# This causes an issue in the multithreading test because we check all events
5# in that test with their tids. The events that correspond to these lingering
6# threads all have TID of (uint64_t)(-1) which is invalid.
# The workaround is turning off the monitoring thread when tqdm is loaded.
8# Since these are unit tests, it is safe to turn off monitor thread.
try:
    import tqdm

    # Disable tqdm's background monitor thread (see comment above).
    tqdm.tqdm.monitor_interval = 0
except ImportError:
    # tqdm is optional; nothing to configure when it is absent.
    pass
15
16from typing import Any, Dict
17
18import torch
19import torch.optim
20import torch.utils.data
21import torch.utils.data.datapipes as dp
22from torch.autograd import (
23    _record_function_with_args_enter,
24    _record_function_with_args_exit,
25)
26from torch.autograd.profiler import profile as _profile
27from torch.profiler import kineto_available, record_function
28from torch.testing._internal.common_utils import run_tests, TestCase
29
30
# Type alias for a parsed JSON object. NOTE(review): not referenced anywhere
# in this file — possibly kept for importers of this module; confirm before
# removing.
Json = Dict[str, Any]
32
33
34class TestRecordFunction(TestCase):
35    def _record_function_with_param(self):
36        u = torch.randn(3, 4, 5, requires_grad=True)
37        with _profile(
38            with_stack=True, use_kineto=kineto_available(), record_shapes=True
39        ) as prof:
40            with record_function("## TEST 1 ##", "1, 2, 3"):
41                rf_handle = _record_function_with_args_enter(
42                    "## TEST 2 ##", 1, False, 2.5, [u, u], "hello", u
43                )
44                _record_function_with_args_exit(rf_handle)
45            with record_function("## TEST 3 ##"):
46                rf_handle = _record_function_with_args_enter("## TEST 4 ##")
47                _record_function_with_args_exit(rf_handle)
48        return prof
49
50    def test_record_function(self):
51        prof_result = self._record_function_with_param()
52        found_test_1 = False
53        found_test_2 = False
54        found_test_3 = False
55        found_test_4 = False
56        for e in prof_result.function_events:
57            if "## TEST 1 ##" == e.name:
58                found_test_1 = True
59                self.assertTrue(e.input_shapes == [[]])
60            elif "## TEST 2 ##" == e.name:
61                found_test_2 = True
62                self.assertTrue(e.input_shapes == [[], [], [], [], [], [3, 4, 5]])
63            elif "## TEST 3 ##" == e.name:
64                found_test_3 = True
65                self.assertTrue(e.input_shapes == [])
66            elif "## TEST 4 ##" == e.name:
67                found_test_4 = True
68                self.assertTrue(e.input_shapes == [])
69        self.assertTrue(found_test_1)
70        self.assertTrue(found_test_2)
71        self.assertTrue(found_test_3)
72        self.assertTrue(found_test_4)
73
74    def test_datapipe_with_record_function(self):
75        with _profile(
76            with_stack=True, use_kineto=kineto_available(), record_shapes=True
77        ) as prof:
78            input_dp1 = dp.iter.IterableWrapper(range(4))
79            input_dp2 = dp.iter.IterableWrapper(range(4, 8))
80            input_dp3 = dp.iter.IterableWrapper(range(8, 12))
81            output_dp = input_dp1.mux(input_dp2, input_dp3)
82            output = list(output_dp)
83
84        has_iter = False
85        has_mux = False
86        for e in prof.function_events:
87            if has_iter and has_mux:
88                break
89
90            if not has_iter and "IterableWrapper" in e.name:
91                has_iter = True
92            if not has_mux and "Multiplexer" in e.name:
93                has_mux = True
94        self.assertTrue(has_iter)
95        self.assertTrue(has_mux)
96
97    def test_datapipe_delegation_with_profiler(self):
98        class IDPIterator(torch.utils.data.IterDataPipe):
99            def __init__(self) -> None:
100                self.data = list(range(10))
101                self._idx = 0
102
103            def __iter__(self):
104                return self
105
106            def __next__(self):
107                if self._idx >= 10:
108                    self._idx = 0
109                    raise StopIteration
110                self._idx += 1
111                return self.data[self._idx - 1]
112
113            def get_value(self, idx):
114                return self.data[idx]
115
116        dp1 = IDPIterator()  # The object itself is an iterator
117        self.assertEqual(5, dp1.get_value(5))
118        it_dp1 = iter(dp1)  # This creates the 1st iterator
119        self.assertEqual(5, it_dp1.get_value(5))  # type: ignore[attr-defined]
120        self.assertEqual(list(range(10)), list(it_dp1))
121
122        class IDPDelegator(torch.utils.data.IterDataPipe):
123            def __init__(self, datapipe):
124                self.datapipe = datapipe
125
126            def __iter__(self):
127                return iter(self.datapipe)
128
129        dp2 = IDPDelegator(dp1)
130        it_dp2 = iter(dp2)
131        self.assertEqual(5, it_dp2.get_value(5))
132        self.assertEqual(list(range(10)), list(it_dp2))
133
134    def test_datapipe_with_record_function_fork(self):
135        with _profile(
136            with_stack=True, use_kineto=kineto_available(), record_shapes=True
137        ) as prof:
138            input_dp = dp.iter.IterableWrapper(range(10))
139            dp1, dp2, dp3 = input_dp.fork(num_instances=3)
140            output1 = list(dp1)
141        has_iter = False
142        has_child = False
143        for e in prof.function_events:
144            if has_iter and has_child:
145                break
146
147            if not has_iter and "IterableWrapper" in e.name:
148                has_iter = True
149            if not has_child and "_ChildDataPipe" in e.name:
150                has_child = True
151        self.assertTrue(has_iter)
152        self.assertTrue(has_child)
153
154
if __name__ == "__main__":
    # Use PyTorch's shared test runner so common CLI flags and result
    # reporting behave consistently with the rest of the test suite.
    run_tests()
157