xref: /aosp_15_r20/external/pytorch/test/test_monitor.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: r2p"]
2
3import tempfile
4import time
5
6from datetime import datetime, timedelta
7
8from torch.monitor import (
9    Aggregation,
10    Event,
11    log_event,
12    register_event_handler,
13    Stat,
14    TensorboardEventHandler,
15    unregister_event_handler,
16    _WaitCounter,
17)
18from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
19
20class TestMonitor(TestCase):
21    def test_interval_stat(self) -> None:
22        events = []
23
24        def handler(event):
25            events.append(event)
26
27        handle = register_event_handler(handler)
28        s = Stat(
29            "asdf",
30            (Aggregation.SUM, Aggregation.COUNT),
31            timedelta(milliseconds=1),
32        )
33        self.assertEqual(s.name, "asdf")
34
35        s.add(2)
36        for _ in range(100):
37            # NOTE: different platforms sleep may be inaccurate so we loop
38            # instead (i.e. win)
39            time.sleep(1 / 1000)  # ms
40            s.add(3)
41            if len(events) >= 1:
42                break
43        self.assertGreaterEqual(len(events), 1)
44        unregister_event_handler(handle)
45
46    def test_fixed_count_stat(self) -> None:
47        s = Stat(
48            "asdf",
49            (Aggregation.SUM, Aggregation.COUNT),
50            timedelta(hours=100),
51            3,
52        )
53        s.add(1)
54        s.add(2)
55        name = s.name
56        self.assertEqual(name, "asdf")
57        self.assertEqual(s.count, 2)
58        s.add(3)
59        self.assertEqual(s.count, 0)
60        self.assertEqual(s.get(), {Aggregation.SUM: 6.0, Aggregation.COUNT: 3})
61
62    def test_log_event(self) -> None:
63        e = Event(
64            name="torch.monitor.TestEvent",
65            timestamp=datetime.now(),
66            data={
67                "str": "a string",
68                "float": 1234.0,
69                "int": 1234,
70            },
71        )
72        self.assertEqual(e.name, "torch.monitor.TestEvent")
73        self.assertIsNotNone(e.timestamp)
74        self.assertIsNotNone(e.data)
75        log_event(e)
76
77    @skipIfTorchDynamo("Really weird error")
78    def test_event_handler(self) -> None:
79        events = []
80
81        def handler(event: Event) -> None:
82            events.append(event)
83
84        handle = register_event_handler(handler)
85        e = Event(
86            name="torch.monitor.TestEvent",
87            timestamp=datetime.now(),
88            data={},
89        )
90        log_event(e)
91        self.assertEqual(len(events), 1)
92        self.assertEqual(events[0], e)
93        log_event(e)
94        self.assertEqual(len(events), 2)
95
96        unregister_event_handler(handle)
97        log_event(e)
98        self.assertEqual(len(events), 2)
99
100    def test_wait_counter(self) -> None:
101        wait_counter = _WaitCounter(
102            "test_wait_counter",
103        )
104        with wait_counter.guard() as wcg:
105            pass
106
107
108@skipIfTorchDynamo("Really weird error")
109class TestMonitorTensorboard(TestCase):
110    def setUp(self):
111        global SummaryWriter, event_multiplexer
112        try:
113            from torch.utils.tensorboard import SummaryWriter
114            from tensorboard.backend.event_processing import (
115                plugin_event_multiplexer as event_multiplexer,
116            )
117        except ImportError:
118            return self.skipTest("Skip the test since TensorBoard is not installed")
119        self.temp_dirs = []
120
121    def create_summary_writer(self):
122        temp_dir = tempfile.TemporaryDirectory()  # noqa: P201
123        self.temp_dirs.append(temp_dir)
124        return SummaryWriter(temp_dir.name)
125
126    def tearDown(self):
127        # Remove directories created by SummaryWriter
128        for temp_dir in self.temp_dirs:
129            temp_dir.cleanup()
130
131    def test_event_handler(self):
132        with self.create_summary_writer() as w:
133            handle = register_event_handler(TensorboardEventHandler(w))
134
135            s = Stat(
136                "asdf",
137                (Aggregation.SUM, Aggregation.COUNT),
138                timedelta(hours=1),
139                5,
140            )
141            for i in range(10):
142                s.add(i)
143            self.assertEqual(s.count, 0)
144
145            unregister_event_handler(handle)
146
147        mul = event_multiplexer.EventMultiplexer()
148        mul.AddRunsFromDirectory(self.temp_dirs[-1].name)
149        mul.Reload()
150        scalar_dict = mul.PluginRunToTagToContent("scalars")
151        raw_result = {
152            tag: mul.Tensors(run, tag)
153            for run, run_dict in scalar_dict.items()
154            for tag in run_dict
155        }
156        scalars = {
157            tag: [e.tensor_proto.float_val[0] for e in events] for tag, events in raw_result.items()
158        }
159        self.assertEqual(scalars, {
160            "asdf.sum": [10],
161            "asdf.count": [5],
162        })
163
164
165if __name__ == '__main__':
166    run_tests()
167