# Owner(s): ["oncall: jit"] import os import sys import torch # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) from torch.testing._internal.jit_utils import JitTestCase if __name__ == "__main__": raise RuntimeError( "This test file is not meant to be run directly, use:\n\n" "\tpython test/test_jit.py TESTNAME\n\n" "instead." ) class TestLogging(JitTestCase): def test_bump_numeric_counter(self): class ModuleThatLogs(torch.jit.ScriptModule): @torch.jit.script_method def forward(self, x): for i in range(x.size(0)): x += 1.0 torch.jit._logging.add_stat_value("foo", 1) if bool(x.sum() > 0.0): torch.jit._logging.add_stat_value("positive", 1) else: torch.jit._logging.add_stat_value("negative", 1) return x logger = torch.jit._logging.LockingLogger() old_logger = torch.jit._logging.set_logger(logger) try: mtl = ModuleThatLogs() for i in range(5): mtl(torch.rand(3, 4, 5)) self.assertEqual(logger.get_counter_val("foo"), 15) self.assertEqual(logger.get_counter_val("positive"), 5) finally: torch.jit._logging.set_logger(old_logger) def test_trace_numeric_counter(self): def foo(x): torch.jit._logging.add_stat_value("foo", 1) return x + 1.0 traced = torch.jit.trace(foo, torch.rand(3, 4)) logger = torch.jit._logging.LockingLogger() old_logger = torch.jit._logging.set_logger(logger) try: traced(torch.rand(3, 4)) self.assertEqual(logger.get_counter_val("foo"), 1) finally: torch.jit._logging.set_logger(old_logger) def test_time_measurement_counter(self): class ModuleThatTimes(torch.jit.ScriptModule): def forward(self, x): tp_start = torch.jit._logging.time_point() for i in range(30): x += 1.0 tp_end = torch.jit._logging.time_point() torch.jit._logging.add_stat_value("mytimer", tp_end - tp_start) return x mtm = ModuleThatTimes() logger = torch.jit._logging.LockingLogger() old_logger = torch.jit._logging.set_logger(logger) try: mtm(torch.rand(3, 4)) self.assertGreater(logger.get_counter_val("mytimer"), 0) finally: torch.jit._logging.set_logger(old_logger) def test_time_measurement_counter_script(self): class ModuleThatTimes(torch.jit.ScriptModule): @torch.jit.script_method def forward(self, x): tp_start = torch.jit._logging.time_point() for i in range(30): x += 1.0 tp_end = torch.jit._logging.time_point() torch.jit._logging.add_stat_value("mytimer", tp_end - tp_start) return x mtm = ModuleThatTimes() logger = torch.jit._logging.LockingLogger() old_logger = torch.jit._logging.set_logger(logger) try: mtm(torch.rand(3, 4)) self.assertGreater(logger.get_counter_val("mytimer"), 0) finally: torch.jit._logging.set_logger(old_logger) def test_counter_aggregation(self): def foo(x): for i in range(3): torch.jit._logging.add_stat_value("foo", 1) return x + 1.0 traced = torch.jit.trace(foo, torch.rand(3, 4)) logger = torch.jit._logging.LockingLogger() logger.set_aggregation_type("foo", torch.jit._logging.AggregationType.AVG) old_logger = torch.jit._logging.set_logger(logger) try: traced(torch.rand(3, 4)) self.assertEqual(logger.get_counter_val("foo"), 1) finally: torch.jit._logging.set_logger(old_logger) def test_logging_levels_set(self): torch._C._jit_set_logging_option("foo") self.assertEqual("foo", torch._C._jit_get_logging_option())