1# Owner(s): ["oncall: jit"] 2 3import os 4import sys 5 6import torch 7 8 9# Make the helper files in test/ importable 10pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 11sys.path.append(pytorch_test_dir) 12from torch.testing._internal.jit_utils import JitTestCase 13 14 15if __name__ == "__main__": 16 raise RuntimeError( 17 "This test file is not meant to be run directly, use:\n\n" 18 "\tpython test/test_jit.py TESTNAME\n\n" 19 "instead." 20 ) 21 22 23class TestLogging(JitTestCase): 24 def test_bump_numeric_counter(self): 25 class ModuleThatLogs(torch.jit.ScriptModule): 26 @torch.jit.script_method 27 def forward(self, x): 28 for i in range(x.size(0)): 29 x += 1.0 30 torch.jit._logging.add_stat_value("foo", 1) 31 32 if bool(x.sum() > 0.0): 33 torch.jit._logging.add_stat_value("positive", 1) 34 else: 35 torch.jit._logging.add_stat_value("negative", 1) 36 return x 37 38 logger = torch.jit._logging.LockingLogger() 39 old_logger = torch.jit._logging.set_logger(logger) 40 try: 41 mtl = ModuleThatLogs() 42 for i in range(5): 43 mtl(torch.rand(3, 4, 5)) 44 45 self.assertEqual(logger.get_counter_val("foo"), 15) 46 self.assertEqual(logger.get_counter_val("positive"), 5) 47 finally: 48 torch.jit._logging.set_logger(old_logger) 49 50 def test_trace_numeric_counter(self): 51 def foo(x): 52 torch.jit._logging.add_stat_value("foo", 1) 53 return x + 1.0 54 55 traced = torch.jit.trace(foo, torch.rand(3, 4)) 56 logger = torch.jit._logging.LockingLogger() 57 old_logger = torch.jit._logging.set_logger(logger) 58 try: 59 traced(torch.rand(3, 4)) 60 61 self.assertEqual(logger.get_counter_val("foo"), 1) 62 finally: 63 torch.jit._logging.set_logger(old_logger) 64 65 def test_time_measurement_counter(self): 66 class ModuleThatTimes(torch.jit.ScriptModule): 67 def forward(self, x): 68 tp_start = torch.jit._logging.time_point() 69 for i in range(30): 70 x += 1.0 71 tp_end = torch.jit._logging.time_point() 72 torch.jit._logging.add_stat_value("mytimer", tp_end - tp_start) 73 return x 74 75 mtm = ModuleThatTimes() 76 logger = torch.jit._logging.LockingLogger() 77 old_logger = torch.jit._logging.set_logger(logger) 78 try: 79 mtm(torch.rand(3, 4)) 80 self.assertGreater(logger.get_counter_val("mytimer"), 0) 81 finally: 82 torch.jit._logging.set_logger(old_logger) 83 84 def test_time_measurement_counter_script(self): 85 class ModuleThatTimes(torch.jit.ScriptModule): 86 @torch.jit.script_method 87 def forward(self, x): 88 tp_start = torch.jit._logging.time_point() 89 for i in range(30): 90 x += 1.0 91 tp_end = torch.jit._logging.time_point() 92 torch.jit._logging.add_stat_value("mytimer", tp_end - tp_start) 93 return x 94 95 mtm = ModuleThatTimes() 96 logger = torch.jit._logging.LockingLogger() 97 old_logger = torch.jit._logging.set_logger(logger) 98 try: 99 mtm(torch.rand(3, 4)) 100 self.assertGreater(logger.get_counter_val("mytimer"), 0) 101 finally: 102 torch.jit._logging.set_logger(old_logger) 103 104 def test_counter_aggregation(self): 105 def foo(x): 106 for i in range(3): 107 torch.jit._logging.add_stat_value("foo", 1) 108 return x + 1.0 109 110 traced = torch.jit.trace(foo, torch.rand(3, 4)) 111 logger = torch.jit._logging.LockingLogger() 112 logger.set_aggregation_type("foo", torch.jit._logging.AggregationType.AVG) 113 old_logger = torch.jit._logging.set_logger(logger) 114 try: 115 traced(torch.rand(3, 4)) 116 117 self.assertEqual(logger.get_counter_val("foo"), 1) 118 finally: 119 torch.jit._logging.set_logger(old_logger) 120 121 def test_logging_levels_set(self): 122 torch._C._jit_set_logging_option("foo") 123 self.assertEqual("foo", torch._C._jit_get_logging_option()) 124