# Owner(s): ["module: dynamo"]
"""Re-run a selection of dynamo test suites with dynamic shapes enabled.

Each suite listed in ``tests`` is cloned by ``make_dynamic_cls`` into a
``DynamicShapes<Name>`` class whose test methods carry a ``_dynamic_shapes``
suffix and run with the config patches applied in ``make_dynamic_cls``.
"""
import unittest
import warnings

from torch._dynamo import config
from torch._dynamo.testing import make_test_cls_with_patches
from torch.fx.experimental import _config as fx_config
from torch.testing._internal.common_utils import slowTest, TEST_Z3


try:
    from . import (
        test_aot_autograd,
        test_ctx_manager,
        test_export,
        test_functions,
        test_higher_order_ops,
        test_misc,
        test_modules,
        test_repros,
        test_sdpa,
        test_subgraphs,
    )
except ImportError:
    # Fall back to top-level imports when this file is not run as part of a
    # package (no parent package for the relative import).
    import test_aot_autograd
    import test_ctx_manager
    import test_export
    import test_functions
    import test_higher_order_ops
    import test_misc
    import test_modules
    import test_repros
    import test_sdpa
    import test_subgraphs


# Registry of every generated dynamic-shapes class, keyed by class name.
test_classes = {}


def make_dynamic_cls(cls):
    """Build, register and return a dynamic-shapes variant of test class ``cls``.

    The variant comes from ``make_test_cls_with_patches`` using the
    ``"DynamicShapes"`` class prefix, the ``"_dynamic_shapes"`` test-name
    suffix, the config patches listed below, and
    ``xfail_prop="_expected_failure_dynamic"``.

    The new class is recorded in ``test_classes``, bound as a global of this
    module under its own name, and has its ``__module__`` set to this module.
    """
    patches = (
        (config, "assume_static_by_default", False),
        (config, "specialize_int", False),
        # Translation validation requires Z3, so it is only enabled when Z3
        # is available.
        (fx_config, "translation_validation", TEST_Z3),
        (fx_config, "check_shape_env_recorded_events", True),
        (fx_config, "validate_shape_env_version_key", True),
    )
    dynamic_cls = make_test_cls_with_patches(
        cls,
        "DynamicShapes",
        "_dynamic_shapes",
        *patches,
        xfail_prop="_expected_failure_dynamic",
    )
    name = dynamic_cls.__name__
    test_classes[name] = dynamic_cls
    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING: the generated class is
    # only discovered if it is bound as a global of this module.
    globals()[name] = dynamic_cls
    dynamic_cls.__module__ = __name__
    return dynamic_cls


tests = [
    test_ctx_manager.CtxManagerTests,
    test_functions.FunctionTests,
    test_misc.MiscTests,
    test_repros.ReproTests,
    test_modules.NNModuleTests,
    test_export.ExportTests,
    test_subgraphs.SubGraphTests,
    test_higher_order_ops.HigherOrderOpTests,
    test_higher_order_ops.FuncTorchHigherOrderOpTests,
    test_aot_autograd.AotAutogradFallbackTests,
    test_sdpa.TestSDPA,
]

for base_cls in tests:
    make_dynamic_cls(base_cls)
# Unbind the loop variable so the last base suite is not left behind as an
# extra module global (where discovery would presumably collect it again).
del base_cls

if TEST_Z3 and not config.inline_inbuilt_nn_modules:
    # TODO model is somehow not being freed when z3 is available
    unittest.expectedFailure(
        test_classes["DynamicShapesMiscTests"].test_parameter_free_dynamic_shapes
    )

# Test is only valid without dynamic shapes. expectedFailure marks the
# function object in place, so its return value is not needed.
unittest.expectedFailure(
    test_classes[
        "DynamicShapesReproTests"
    ].test_many_views_with_mutation_dynamic_shapes
)

# Wrap the slowest dynamic-shapes export tests with slowTest.
_export_dynamic = test_classes["DynamicShapesExportTests"]
for _slow_name in (
    # Test takes too long ~700s as of 414a1fd29f04d06e41b7f895368dd1f83a4be29d
    "test_retracibility_dynamic_shapes",
    # Also take more than 30m as of 15cc9f2e7e7b2b175f24755925dc38d4d430905d
    "test_retracibility_dict_container_inp_out_dynamic_shapes",
    "test_retracibility_nested_list_out_dynamic_shapes",
):
    setattr(
        _export_dynamic, _slow_name, slowTest(getattr(_export_dynamic, _slow_name))
    )
# Drop the helper names; keeping the class alias as a global would expose it
# to discovery under a second name.
del _export_dynamic, _slow_name


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    if not TEST_Z3:
        warnings.warn(
            "translation validation is off. "
            "Testing with translation validation requires Z3."
        )

    run_tests()