# mypy: allow-untyped-defs
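"""Shared test-case infrastructure for torch._dynamo's own test suite: a
``run_tests`` entry point that skips environments where these tests do not
apply, and a ``TestCase`` base class that patches dynamo config for the
class and resets dynamo state around every test."""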
import contextlib
import importlib
import logging

import torch
import torch.testing
from torch.testing._internal.common_utils import (  # type: ignore[attr-defined]
    IS_WINDOWS,
    TEST_WITH_CROSSREF,
    TEST_WITH_TORCHDYNAMO,
    TestCase as TorchTestCase,
)

from . import config, reset, utils


log = logging.getLogger(__name__)


def run_tests(needs=()):
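    """Run the tests in the calling module, unless the environment is one
    where they are known not to apply (already running under TorchDynamo,
    on Windows, or under crossref testing) or a dependency named in
    ``needs`` is unavailable. ``needs`` may be a single string or a tuple;
    the special value "cuda" checks ``torch.cuda.is_available()`` rather
    than importing a module."""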
    from torch.testing._internal.common_utils import run_tests

    if TEST_WITH_TORCHDYNAMO or IS_WINDOWS or TEST_WITH_CROSSREF:
        return  # skip testing

    if isinstance(needs, str):
        needs = (needs,)
    for need in needs:
        if need == "cuda":
            if not torch.cuda.is_available():
                return
        else:
            try:
                importlib.import_module(need)
            except ImportError:
                return
    run_tests()
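

# A minimal sketch of how a test module typically uses these helpers (the
# module contents and test names here are hypothetical):
#
#     import torch
#     from torch._dynamo.test_case import run_tests, TestCase
#
#     class MyDynamoTests(TestCase):
#         def test_add(self):
#             compiled = torch.compile(lambda x: x + 1)
#             self.assertTrue(torch.equal(compiled(torch.ones(2)), torch.ones(2) + 1))
#
#     if __name__ == "__main__":
#         run_tests()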


class TestCase(TorchTestCase):
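    """Base class for dynamo tests.

    On top of the stock PyTorch ``TestCase``, this patches dynamo's config
    for the lifetime of the class (raising on context-manager usage,
    surfacing rather than suppressing errors, and disabling compilation
    metrics) and resets dynamo state and counters around every test.
    """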
    _exit_stack: contextlib.ExitStack

    @classmethod
    def tearDownClass(cls):
        cls._exit_stack.close()
        super().tearDownClass()

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._exit_stack = contextlib.ExitStack()  # type: ignore[attr-defined]
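        # Keep the config patch active for the whole class; it is popped by
        # the ExitStack in tearDownClass.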
        cls._exit_stack.enter_context(  # type: ignore[attr-defined]
            config.patch(
                raise_on_ctx_manager_usage=True,
                suppress_errors=False,
                log_compilation_metrics=False,
            ),
        )

    def setUp(self):
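        # Snapshot grad mode so tearDown can detect and undo tests that
        # leak a changed grad-mode setting.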
        self._prior_is_grad_enabled = torch.is_grad_enabled()
        super().setUp()
        reset()
        utils.counters.clear()

    def tearDown(self):
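        # Dump the dynamo counters collected during the test, then reset
        # state so the next test starts clean.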
        for k, v in utils.counters.items():
            print(k, v.most_common())
        reset()
        utils.counters.clear()
        super().tearDown()
        if self._prior_is_grad_enabled is not torch.is_grad_enabled():
            log.warning("Running test changed grad mode")
            torch.set_grad_enabled(self._prior_is_grad_enabled)