# mypy: allow-untyped-defs
import contextlib
import importlib
import logging

import torch
import torch.testing
from torch.testing._internal.common_utils import (  # type: ignore[attr-defined]
    IS_WINDOWS,
    TEST_WITH_CROSSREF,
    TEST_WITH_TORCHDYNAMO,
    TestCase as TorchTestCase,
)

from . import config, reset, utils


log = logging.getLogger(__name__)


def _requirement_met(need):
    """Return True when a single requirement named by *need* is satisfied.

    The special name ``"cuda"`` is answered by ``torch.cuda.is_available()``;
    any other name is treated as a module name and counts as satisfied when
    ``importlib.import_module`` does not raise ``ImportError``.
    """
    if need == "cuda":
        return torch.cuda.is_available()
    try:
        importlib.import_module(need)
    except ImportError:
        return False
    return True


def run_tests(needs=()):
    """Run the common test runner unless tests must be skipped.

    Nothing is run when any of ``TEST_WITH_TORCHDYNAMO``, ``IS_WINDOWS`` or
    ``TEST_WITH_CROSSREF`` is set, or when one of the requirements in *needs*
    is not available.

    Args:
        needs: A single requirement name or an iterable of names. ``"cuda"``
            requires an available CUDA device; every other name must be an
            importable module.
    """
    # Aliased so the imported runner does not shadow this function's own name.
    from torch.testing._internal.common_utils import run_tests as _run_torch_tests

    skip_everything = TEST_WITH_TORCHDYNAMO or IS_WINDOWS or TEST_WITH_CROSSREF
    if skip_everything:
        return  # skip testing

    requirements = (needs,) if isinstance(needs, str) else needs
    # all() stops at the first unmet requirement, so later modules are not
    # imported once one is found missing.
    if all(_requirement_met(requirement) for requirement in requirements):
        _run_torch_tests()


class TestCase(TorchTestCase):
    """Test base class that patches ``config`` per class and resets per test.

    For the lifetime of the class a ``config.patch(...)`` context is held open
    on an ``ExitStack``. Around every test, ``reset()`` is called and
    ``utils.counters`` is cleared; after each test the grad mode observed in
    ``setUp`` is restored if the test changed it.
    """

    # Holds the config patch entered in setUpClass; closed in tearDownClass.
    _exit_stack: contextlib.ExitStack

    @classmethod
    def setUpClass(cls):
        """Run the parent class setup, then enter the config patch."""
        super().setUpClass()
        stack = contextlib.ExitStack()
        cls._exit_stack = stack  # type: ignore[attr-defined]
        patched_config = config.patch(
            raise_on_ctx_manager_usage=True,
            suppress_errors=False,
            log_compilation_metrics=False,
        )
        stack.enter_context(patched_config)

    @classmethod
    def tearDownClass(cls):
        """Leave the config patch, then run the parent class teardown."""
        cls._exit_stack.close()
        super().tearDownClass()

    def setUp(self):
        """Record the current grad mode and start the test from a reset state."""
        # Captured before the parent setUp runs so tearDown can compare to it.
        self._prior_is_grad_enabled = torch.is_grad_enabled()
        super().setUp()
        reset()
        utils.counters.clear()

    def tearDown(self):
        """Print the counters, reset state, and restore grad mode if changed."""
        for name, counter in utils.counters.items():
            print(name, counter.most_common())
        reset()
        utils.counters.clear()
        super().tearDown()

        previous_mode = self._prior_is_grad_enabled
        if previous_mode is torch.is_grad_enabled():
            return
        # The test left grad mode different from what setUp observed.
        log.warning("Running test changed grad mode")
        torch.set_grad_enabled(previous_mode)