1# Owner(s): ["oncall: r2p"] 2 3import tempfile 4import time 5 6from datetime import datetime, timedelta 7 8from torch.monitor import ( 9 Aggregation, 10 Event, 11 log_event, 12 register_event_handler, 13 Stat, 14 TensorboardEventHandler, 15 unregister_event_handler, 16 _WaitCounter, 17) 18from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase 19 20class TestMonitor(TestCase): 21 def test_interval_stat(self) -> None: 22 events = [] 23 24 def handler(event): 25 events.append(event) 26 27 handle = register_event_handler(handler) 28 s = Stat( 29 "asdf", 30 (Aggregation.SUM, Aggregation.COUNT), 31 timedelta(milliseconds=1), 32 ) 33 self.assertEqual(s.name, "asdf") 34 35 s.add(2) 36 for _ in range(100): 37 # NOTE: different platforms sleep may be inaccurate so we loop 38 # instead (i.e. win) 39 time.sleep(1 / 1000) # ms 40 s.add(3) 41 if len(events) >= 1: 42 break 43 self.assertGreaterEqual(len(events), 1) 44 unregister_event_handler(handle) 45 46 def test_fixed_count_stat(self) -> None: 47 s = Stat( 48 "asdf", 49 (Aggregation.SUM, Aggregation.COUNT), 50 timedelta(hours=100), 51 3, 52 ) 53 s.add(1) 54 s.add(2) 55 name = s.name 56 self.assertEqual(name, "asdf") 57 self.assertEqual(s.count, 2) 58 s.add(3) 59 self.assertEqual(s.count, 0) 60 self.assertEqual(s.get(), {Aggregation.SUM: 6.0, Aggregation.COUNT: 3}) 61 62 def test_log_event(self) -> None: 63 e = Event( 64 name="torch.monitor.TestEvent", 65 timestamp=datetime.now(), 66 data={ 67 "str": "a string", 68 "float": 1234.0, 69 "int": 1234, 70 }, 71 ) 72 self.assertEqual(e.name, "torch.monitor.TestEvent") 73 self.assertIsNotNone(e.timestamp) 74 self.assertIsNotNone(e.data) 75 log_event(e) 76 77 @skipIfTorchDynamo("Really weird error") 78 def test_event_handler(self) -> None: 79 events = [] 80 81 def handler(event: Event) -> None: 82 events.append(event) 83 84 handle = register_event_handler(handler) 85 e = Event( 86 name="torch.monitor.TestEvent", 87 timestamp=datetime.now(), 88 data={}, 89 ) 90 log_event(e) 91 self.assertEqual(len(events), 1) 92 self.assertEqual(events[0], e) 93 log_event(e) 94 self.assertEqual(len(events), 2) 95 96 unregister_event_handler(handle) 97 log_event(e) 98 self.assertEqual(len(events), 2) 99 100 def test_wait_counter(self) -> None: 101 wait_counter = _WaitCounter( 102 "test_wait_counter", 103 ) 104 with wait_counter.guard() as wcg: 105 pass 106 107 108@skipIfTorchDynamo("Really weird error") 109class TestMonitorTensorboard(TestCase): 110 def setUp(self): 111 global SummaryWriter, event_multiplexer 112 try: 113 from torch.utils.tensorboard import SummaryWriter 114 from tensorboard.backend.event_processing import ( 115 plugin_event_multiplexer as event_multiplexer, 116 ) 117 except ImportError: 118 return self.skipTest("Skip the test since TensorBoard is not installed") 119 self.temp_dirs = [] 120 121 def create_summary_writer(self): 122 temp_dir = tempfile.TemporaryDirectory() # noqa: P201 123 self.temp_dirs.append(temp_dir) 124 return SummaryWriter(temp_dir.name) 125 126 def tearDown(self): 127 # Remove directories created by SummaryWriter 128 for temp_dir in self.temp_dirs: 129 temp_dir.cleanup() 130 131 def test_event_handler(self): 132 with self.create_summary_writer() as w: 133 handle = register_event_handler(TensorboardEventHandler(w)) 134 135 s = Stat( 136 "asdf", 137 (Aggregation.SUM, Aggregation.COUNT), 138 timedelta(hours=1), 139 5, 140 ) 141 for i in range(10): 142 s.add(i) 143 self.assertEqual(s.count, 0) 144 145 unregister_event_handler(handle) 146 147 mul = event_multiplexer.EventMultiplexer() 148 mul.AddRunsFromDirectory(self.temp_dirs[-1].name) 149 mul.Reload() 150 scalar_dict = mul.PluginRunToTagToContent("scalars") 151 raw_result = { 152 tag: mul.Tensors(run, tag) 153 for run, run_dict in scalar_dict.items() 154 for tag in run_dict 155 } 156 scalars = { 157 tag: [e.tensor_proto.float_val[0] for e in events] for tag, events in raw_result.items() 158 } 159 self.assertEqual(scalars, { 160 "asdf.sum": [10], 161 "asdf.count": [5], 162 }) 163 164 165if __name__ == '__main__': 166 run_tests() 167