1# mypy: ignore-errors 2 3import torch._dynamo.test_case 4import unittest.mock 5import os 6import contextlib 7import torch._logging 8import torch._logging._internal 9from torch._dynamo.utils import LazyString 10from torch._inductor import config as inductor_config 11import logging 12import io 13 14@contextlib.contextmanager 15def preserve_log_state(): 16 prev_state = torch._logging._internal._get_log_state() 17 torch._logging._internal._set_log_state(torch._logging._internal.LogState()) 18 try: 19 yield 20 finally: 21 torch._logging._internal._set_log_state(prev_state) 22 torch._logging._internal._init_logs() 23 24def log_settings(settings): 25 exit_stack = contextlib.ExitStack() 26 settings_patch = unittest.mock.patch.dict(os.environ, {"TORCH_LOGS": settings}) 27 exit_stack.enter_context(preserve_log_state()) 28 exit_stack.enter_context(settings_patch) 29 torch._logging._internal._init_logs() 30 return exit_stack 31 32def log_api(**kwargs): 33 exit_stack = contextlib.ExitStack() 34 exit_stack.enter_context(preserve_log_state()) 35 torch._logging.set_logs(**kwargs) 36 return exit_stack 37 38 39def kwargs_to_settings(**kwargs): 40 INT_TO_VERBOSITY = {10: "+", 20: "", 40: "-"} 41 42 settings = [] 43 44 def append_setting(name, level): 45 if isinstance(name, str) and isinstance(level, int) and level in INT_TO_VERBOSITY: 46 settings.append(INT_TO_VERBOSITY[level] + name) 47 return 48 else: 49 raise ValueError("Invalid value for setting") 50 51 for name, val in kwargs.items(): 52 if isinstance(val, bool): 53 settings.append(name) 54 elif isinstance(val, int): 55 append_setting(name, val) 56 elif isinstance(val, dict) and name == "modules": 57 for module_qname, level in val.items(): 58 append_setting(module_qname, level) 59 else: 60 raise ValueError("Invalid value for setting") 61 62 return ",".join(settings) 63 64 65# Note on testing strategy: 66# This class does two things: 67# 1. Runs two versions of a test: 68# 1a. patches the env var log settings to some specific value 69# 1b. calls torch._logging.set_logs(..) 70# 2. patches the emit method of each setup handler to gather records 71# that are emitted to each console stream 72# 3. passes a ref to the gathered records to each test case for checking 73# 74# The goal of this testing in general is to ensure that given some settings env var 75# that the logs are setup correctly and capturing the correct records. 76def make_logging_test(**kwargs): 77 def wrapper(fn): 78 @inductor_config.patch({"fx_graph_cache": False}) 79 def test_fn(self): 80 81 torch._dynamo.reset() 82 records = [] 83 # run with env var 84 if len(kwargs) == 0: 85 with self._handler_watcher(records): 86 fn(self, records) 87 else: 88 with log_settings(kwargs_to_settings(**kwargs)), self._handler_watcher(records): 89 fn(self, records) 90 91 # run with API 92 torch._dynamo.reset() 93 records.clear() 94 with log_api(**kwargs), self._handler_watcher(records): 95 fn(self, records) 96 97 98 return test_fn 99 100 return wrapper 101 102def make_settings_test(settings): 103 def wrapper(fn): 104 def test_fn(self): 105 torch._dynamo.reset() 106 records = [] 107 # run with env var 108 with log_settings(settings), self._handler_watcher(records): 109 fn(self, records) 110 111 return test_fn 112 113 return wrapper 114 115class LoggingTestCase(torch._dynamo.test_case.TestCase): 116 @classmethod 117 def setUpClass(cls): 118 super().setUpClass() 119 cls._exit_stack.enter_context( 120 unittest.mock.patch.dict(os.environ, {"___LOG_TESTING": ""}) 121 ) 122 cls._exit_stack.enter_context( 123 unittest.mock.patch("torch._dynamo.config.suppress_errors", True) 124 ) 125 cls._exit_stack.enter_context( 126 unittest.mock.patch("torch._dynamo.config.verbose", False) 127 ) 128 129 @classmethod 130 def tearDownClass(cls): 131 cls._exit_stack.close() 132 torch._logging._internal.log_state.clear() 133 torch._logging._init_logs() 134 135 def hasRecord(self, records, m): 136 return any(m in r.getMessage() for r in records) 137 138 def getRecord(self, records, m): 139 record = None 140 for r in records: 141 # NB: not r.msg because it looks like 3.11 changed how they 142 # structure log records 143 if m in r.getMessage(): 144 self.assertIsNone( 145 record, 146 msg=LazyString( 147 lambda: f"multiple matching records: {record} and {r} among {records}" 148 ), 149 ) 150 record = r 151 if record is None: 152 self.fail(f"did not find record with {m} among {records}") 153 return record 154 155 # This patches the emit method of each handler to gather records 156 # as they are emitted 157 def _handler_watcher(self, record_list): 158 exit_stack = contextlib.ExitStack() 159 160 def emit_post_hook(record): 161 nonlocal record_list 162 record_list.append(record) 163 164 # registered logs are the only ones with handlers, so patch those 165 for log_qname in torch._logging._internal.log_registry.get_log_qnames(): 166 logger = logging.getLogger(log_qname) 167 num_handlers = len(logger.handlers) 168 self.assertLessEqual( 169 num_handlers, 170 2, 171 "All pt2 loggers should only have at most two handlers (debug artifacts and messages above debug level).", 172 ) 173 174 self.assertGreater(num_handlers, 0, "All pt2 loggers should have more than zero handlers") 175 176 for handler in logger.handlers: 177 old_emit = handler.emit 178 179 def new_emit(record): 180 old_emit(record) 181 emit_post_hook(record) 182 183 exit_stack.enter_context( 184 unittest.mock.patch.object(handler, "emit", new_emit) 185 ) 186 187 return exit_stack 188 189 190def logs_to_string(module, log_option): 191 """Example: 192 logs_to_string("torch._inductor.compile_fx", "post_grad_graphs") 193 returns the output of TORCH_LOGS="post_grad_graphs" from the 194 torch._inductor.compile_fx module. 195 """ 196 log_stream = io.StringIO() 197 handler = logging.StreamHandler(stream=log_stream) 198 199 @contextlib.contextmanager 200 def tmp_redirect_logs(): 201 try: 202 logger = torch._logging.getArtifactLogger(module, log_option) 203 logger.addHandler(handler) 204 yield 205 finally: 206 logger.removeHandler(handler) 207 208 def ctx_manager(): 209 exit_stack = log_settings(log_option) 210 exit_stack.enter_context(tmp_redirect_logs()) 211 return exit_stack 212 213 return log_stream, ctx_manager 214