# mypy: ignore-errors

import torch._dynamo.test_case
import unittest.mock
import os
import contextlib
import torch._logging
import torch._logging._internal
from torch._dynamo.utils import LazyString
from torch._inductor import config as inductor_config
import logging
import io

# Swap in a fresh LogState for the duration of the context, then restore the
# previous state and re-initialize the log handlers on exit.
@contextlib.contextmanager
def preserve_log_state():
    prev_state = torch._logging._internal._get_log_state()
    torch._logging._internal._set_log_state(torch._logging._internal.LogState())
    try:
        yield
    finally:
        torch._logging._internal._set_log_state(prev_state)
        torch._logging._internal._init_logs()

# Configure logging via the TORCH_LOGS environment variable; returns an
# ExitStack that restores the previous configuration when closed.
def log_settings(settings):
    exit_stack = contextlib.ExitStack()
    settings_patch = unittest.mock.patch.dict(os.environ, {"TORCH_LOGS": settings})
    exit_stack.enter_context(preserve_log_state())
    exit_stack.enter_context(settings_patch)
    torch._logging._internal._init_logs()
    return exit_stack

# Configure logging via the torch._logging.set_logs API; returns an ExitStack
# that restores the previous configuration when closed.
def log_api(**kwargs):
    exit_stack = contextlib.ExitStack()
    exit_stack.enter_context(preserve_log_state())
    torch._logging.set_logs(**kwargs)
    return exit_stack
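
# A minimal usage sketch (illustrative only, not exercised by this module):
# both helpers return an ExitStack that restores the previous logging
# configuration when it is closed or exited.
#
#     with log_settings("+dynamo"):
#         ...  # TORCH_LOGS="+dynamo" is in effect here
#
#     with log_api(dynamo=logging.DEBUG):
#         ...  # the same configuration, via torch._logging.set_logs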


# Translate torch._logging.set_logs-style kwargs into the equivalent
# TORCH_LOGS settings string.
def kwargs_to_settings(**kwargs):
    INT_TO_VERBOSITY = {10: "+", 20: "", 40: "-"}

    settings = []

    def append_setting(name, level):
        if isinstance(name, str) and isinstance(level, int) and level in INT_TO_VERBOSITY:
            settings.append(INT_TO_VERBOSITY[level] + name)
        else:
            raise ValueError("Invalid value for setting")

    for name, val in kwargs.items():
        if isinstance(val, bool):
            settings.append(name)
        elif isinstance(val, int):
            append_setting(name, val)
        elif isinstance(val, dict) and name == "modules":
            for module_qname, level in val.items():
                append_setting(module_qname, level)
        else:
            raise ValueError("Invalid value for setting")

    return ",".join(settings)
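
# Illustrative mappings (the inputs below are hypothetical examples, not
# exercised by this module):
#
#     kwargs_to_settings(dynamo=logging.DEBUG, graph_code=True)
#     # -> "+dynamo,graph_code"
#     kwargs_to_settings(modules={"torch._dynamo.symbolic_convert": logging.INFO})
#     # -> "torch._dynamo.symbolic_convert"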


# Note on testing strategy:
# make_logging_test (together with LoggingTestCase) does three things:
# 1. Runs two versions of each test:
#    1a. one that patches the TORCH_LOGS env var to the requested settings
#    1b. one that calls torch._logging.set_logs(..) with the same settings
# 2. Patches the emit method of each registered handler to gather the records
#    emitted to each console stream.
# 3. Passes a reference to the gathered records to each test case for checking.
#
# The goal of this testing in general is to ensure that, given some settings
# (whether from the env var or the API), the logs are set up correctly and
# capture the correct records.
def make_logging_test(**kwargs):
    def wrapper(fn):
        @inductor_config.patch({"fx_graph_cache": False})
        def test_fn(self):
            torch._dynamo.reset()
            records = []
            # run with env var
            if len(kwargs) == 0:
                with self._handler_watcher(records):
                    fn(self, records)
            else:
                with log_settings(kwargs_to_settings(**kwargs)), self._handler_watcher(records):
                    fn(self, records)

            # run with API
            torch._dynamo.reset()
            records.clear()
            with log_api(**kwargs), self._handler_watcher(records):
                fn(self, records)

        return test_fn

    return wrapper
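
# A minimal usage sketch (the test below is hypothetical): each decorated test
# receives the captured records and is run twice, once with the TORCH_LOGS env
# var set and once through the torch._logging.set_logs API.
#
#     class ExampleLoggingTests(LoggingTestCase):
#         @make_logging_test(dynamo=logging.INFO)
#         def test_logs_something(self, records):
#             torch.compile(lambda x: x + 1, backend="eager")(torch.ones(3))
#             self.assertGreater(len(records), 0)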

def make_settings_test(settings):
    def wrapper(fn):
        def test_fn(self):
            torch._dynamo.reset()
            records = []
            # run with env var
            with log_settings(settings), self._handler_watcher(records):
                fn(self, records)

        return test_fn

    return wrapper
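
# Usage sketch (hypothetical test body): unlike make_logging_test, this takes
# a raw TORCH_LOGS settings string and only runs the env-var variant.
#
#     @make_settings_test("graph_breaks")
#     def test_raw_settings(self, records):
#         torch.compile(lambda x: x - 1, backend="eager")(torch.ones(3))
#         self.assertGreater(len(records), 0)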

class LoggingTestCase(torch._dynamo.test_case.TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._exit_stack.enter_context(
            unittest.mock.patch.dict(os.environ, {"___LOG_TESTING": ""})
        )
        cls._exit_stack.enter_context(
            unittest.mock.patch("torch._dynamo.config.suppress_errors", True)
        )
        cls._exit_stack.enter_context(
            unittest.mock.patch("torch._dynamo.config.verbose", False)
        )

    @classmethod
    def tearDownClass(cls):
        cls._exit_stack.close()
        torch._logging._internal.log_state.clear()
        torch._logging._init_logs()

    def hasRecord(self, records, m):
        return any(m in r.getMessage() for r in records)

    def getRecord(self, records, m):
        record = None
        for r in records:
            # NB: not r.msg because it looks like 3.11 changed how they
            # structure log records
            if m in r.getMessage():
                self.assertIsNone(
                    record,
                    msg=LazyString(
                        lambda: f"multiple matching records: {record} and {r} among {records}"
                    ),
                )
                record = r
        if record is None:
            self.fail(f"did not find record with {m} among {records}")
        return record
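
    # A hedged sketch of how these helpers are typically used inside a test
    # (the message substring below is illustrative, not a real log line):
    #
    #     record = self.getRecord(records, "some expected substring")
    #     self.assertIn("some expected substring", record.getMessage())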

    # This patches the emit method of each handler to gather records
    # as they are emitted
    def _handler_watcher(self, record_list):
        exit_stack = contextlib.ExitStack()

        def emit_post_hook(record):
            nonlocal record_list
            record_list.append(record)

        # registered logs are the only ones with handlers, so patch those
        for log_qname in torch._logging._internal.log_registry.get_log_qnames():
            logger = logging.getLogger(log_qname)
            num_handlers = len(logger.handlers)
            self.assertLessEqual(
                num_handlers,
                2,
                "All pt2 loggers should only have at most two handlers (debug artifacts and messages above debug level).",
            )

            self.assertGreater(num_handlers, 0, "All pt2 loggers should have at least one handler")

            for handler in logger.handlers:
                old_emit = handler.emit

                # Bind old_emit as a default argument so each patched handler
                # calls its own original emit rather than whichever emit was
                # assigned last in this loop.
                def new_emit(record, old_emit=old_emit):
                    old_emit(record)
                    emit_post_hook(record)

                exit_stack.enter_context(
                    unittest.mock.patch.object(handler, "emit", new_emit)
                )

        return exit_stack


def logs_to_string(module, log_option):
    """Example:
    logs_to_string("torch._inductor.compile_fx", "post_grad_graphs")
    returns a (stream, context-manager factory) pair that captures the output
    TORCH_LOGS="post_grad_graphs" would produce from the
    torch._inductor.compile_fx module.
    """
    log_stream = io.StringIO()
    handler = logging.StreamHandler(stream=log_stream)

    @contextlib.contextmanager
    def tmp_redirect_logs():
        try:
            logger = torch._logging.getArtifactLogger(module, log_option)
            logger.addHandler(handler)
            yield
        finally:
            logger.removeHandler(handler)

    def ctx_manager():
        exit_stack = log_settings(log_option)
        exit_stack.enter_context(tmp_redirect_logs())
        return exit_stack

    return log_stream, ctx_manager
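
# A hedged usage sketch (the call site below is hypothetical): the returned
# stream accumulates the artifact output produced while the context is active.
#
#     log_stream, ctx_manager = logs_to_string(
#         "torch._inductor.compile_fx", "post_grad_graphs"
#     )
#     with ctx_manager():
#         torch.compile(some_fn)(some_input)
#     post_grad_graphs = log_stream.getvalue()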