# Owner(s): ["module: dynamo"]
"""Re-run a selection of Dynamo test suites with dynamic shapes enabled.

Each suite listed in ``tests`` is cloned by ``make_dynamic_cls`` into a
``DynamicShapes<Name>`` class whose test methods carry a ``_dynamic_shapes``
suffix. The clones run with config patches that turn off
``assume_static_by_default`` and ``specialize_int``, set FX
``translation_validation`` to ``TEST_Z3``, and turn on the shape-env
event/version-key checks.
"""
import unittest
import warnings

from torch._dynamo import config
from torch._dynamo.testing import make_test_cls_with_patches
from torch.fx.experimental import _config as fx_config
from torch.testing._internal.common_utils import slowTest, TEST_Z3


try:
    from . import (
        test_aot_autograd,
        test_ctx_manager,
        test_export,
        test_functions,
        test_higher_order_ops,
        test_misc,
        test_modules,
        test_repros,
        test_sdpa,
        test_subgraphs,
    )
except ImportError:
    # The relative import fails when there is no parent package (presumably
    # when this file is run directly as a script); fall back to absolute
    # imports of the sibling test modules.
    import test_aot_autograd
    import test_ctx_manager
    import test_export
    import test_functions
    import test_higher_order_ops
    import test_misc

    import test_modules
    import test_repros
    import test_sdpa
    import test_subgraphs


# Maps generated class name -> generated dynamic-shapes test class.
test_classes = {}


def make_dynamic_cls(cls):
    """Build the dynamic-shapes variant of test class ``cls`` and register it.

    The generated class is named with the ``DynamicShapes`` prefix, and its
    test methods get the ``_dynamic_shapes`` suffix. The config patches
    listed below are handed to ``make_test_cls_with_patches`` (presumably
    applied while each generated test runs). ``_expected_failure_dynamic``
    is passed as ``xfail_prop``, presumably so that source tests carrying
    that property are marked as expected failures in the generated class.

    Args:
        cls: The source test class to clone.

    Returns:
        The generated test class. It is also stored in ``test_classes`` and
        bound as a global of this module.
    """
    suffix = "_dynamic_shapes"

    cls_prefix = "DynamicShapes"

    test_class = make_test_cls_with_patches(
        cls,
        cls_prefix,
        suffix,
        (config, "assume_static_by_default", False),
        (config, "specialize_int", False),
        (fx_config, "translation_validation", TEST_Z3),
        (fx_config, "check_shape_env_recorded_events", True),
        (fx_config, "validate_shape_env_version_key", True),
        xfail_prop="_expected_failure_dynamic",
    )

    test_classes[test_class.__name__] = test_class
    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
    # (test runners discover TestCase classes from module globals, so the
    # generated class must be bound to a name in this module's namespace).
    globals()[test_class.__name__] = test_class
    # Attribute the generated class to this module (presumably so test
    # reporting/selection treats it as defined here) — TODO confirm.
    test_class.__module__ = __name__
    return test_class


tests = [
    test_ctx_manager.CtxManagerTests,
    test_functions.FunctionTests,
    test_misc.MiscTests,
    test_repros.ReproTests,
    test_modules.NNModuleTests,
    test_export.ExportTests,
    test_subgraphs.SubGraphTests,
    test_higher_order_ops.HigherOrderOpTests,
    test_higher_order_ops.FuncTorchHigherOrderOpTests,
    test_aot_autograd.AotAutogradFallbackTests,
    test_sdpa.TestSDPA,
]
for test in tests:
    make_dynamic_cls(test)
# Drop the loop variable so the last *original* suite class is not left as a
# module global (presumably to keep the test loader from collecting and
# re-running the non-dynamic suite from this module).
del test

# Decorators below are applied by reassignment, matching the slowTest lines
# further down, rather than relying on expectedFailure mutating in place.
if TEST_Z3 and not config.inline_inbuilt_nn_modules:
    # TODO model is somehow not being freed when z3 is available
    DynamicShapesMiscTests.test_parameter_free_dynamic_shapes = (  # noqa: F821
        unittest.expectedFailure(
            DynamicShapesMiscTests.test_parameter_free_dynamic_shapes  # noqa: F821
        )
    )

# Test is only valid without dynamic shapes
DynamicShapesReproTests.test_many_views_with_mutation_dynamic_shapes = (  # noqa: F821
    unittest.expectedFailure(
        DynamicShapesReproTests.test_many_views_with_mutation_dynamic_shapes  # noqa: F821
    )
)

# Test takes too long ~700s as of 414a1fd29f04d06e41b7f895368dd1f83a4be29d
DynamicShapesExportTests.test_retracibility_dynamic_shapes = slowTest(  # noqa: F821
    DynamicShapesExportTests.test_retracibility_dynamic_shapes  # noqa: F821
)
# Also take more than 30m as of 15cc9f2e7e7b2b175f24755925dc38d4d430905d
DynamicShapesExportTests.test_retracibility_dict_container_inp_out_dynamic_shapes = slowTest(  # noqa: F821
    DynamicShapesExportTests.test_retracibility_dict_container_inp_out_dynamic_shapes  # noqa: F821
)
DynamicShapesExportTests.test_retracibility_nested_list_out_dynamic_shapes = slowTest(  # noqa: F821
    DynamicShapesExportTests.test_retracibility_nested_list_out_dynamic_shapes  # noqa: F821
)

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    if not TEST_Z3:
        warnings.warn(
            "translation validation is off. "
            "Testing with translation validation requires Z3."
        )

    run_tests()