xref: /aosp_15_r20/external/pytorch/test/dynamo/test_dynamic_shapes.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dynamo"]
2import unittest
3import warnings
4
5from torch._dynamo import config
6from torch._dynamo.testing import make_test_cls_with_patches
7from torch.fx.experimental import _config as fx_config
8from torch.testing._internal.common_utils import slowTest, TEST_Z3
9
10
11try:
12    from . import (
13        test_aot_autograd,
14        test_ctx_manager,
15        test_export,
16        test_functions,
17        test_higher_order_ops,
18        test_misc,
19        test_modules,
20        test_repros,
21        test_sdpa,
22        test_subgraphs,
23    )
24except ImportError:
25    import test_aot_autograd
26    import test_ctx_manager
27    import test_export
28    import test_functions
29    import test_higher_order_ops
30    import test_misc
31
32    import test_modules
33    import test_repros
34    import test_sdpa
35    import test_subgraphs
36
37
# Registry of every generated dynamic-shapes test class, keyed by the
# generated class name; filled in by make_dynamic_cls.
test_classes = dict()
39
40
def make_dynamic_cls(cls):
    """Create and register a dynamic-shapes variant of the test class ``cls``.

    The variant is produced by ``make_test_cls_with_patches`` using:

    * class-name prefix ``"DynamicShapes"`` and test-name suffix
      ``"_dynamic_shapes"``;
    * config patches that turn off ``assume_static_by_default`` and
      ``specialize_int``;
    * config patches that set fx ``translation_validation`` to ``TEST_Z3`` and
      enable the shape-env event/version-key checks.

    ``xfail_prop="_expected_failure_dynamic"`` is forwarded unchanged. It is
    presumably the attribute ``make_test_cls_with_patches`` consults to mark
    expected failures; confirm in ``torch._dynamo.testing``.

    Args:
        cls: The base test class to derive the dynamic-shapes variant from.

    Returns:
        The generated class. It is also stored in ``test_classes`` and bound as
        a global of this module under its own ``__name__``.
    """
    dynamic_patches = (
        (config, "assume_static_by_default", False),
        (config, "specialize_int", False),
        (fx_config, "translation_validation", TEST_Z3),
        (fx_config, "check_shape_env_recorded_events", True),
        (fx_config, "validate_shape_env_version_key", True),
    )
    generated = make_test_cls_with_patches(
        cls,
        "DynamicShapes",
        "_dynamic_shapes",
        *dynamic_patches,
        xfail_prop="_expected_failure_dynamic",
    )

    generated_name = generated.__name__
    test_classes[generated_name] = generated
    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
    globals()[generated_name] = generated
    # Report this module as the class's home (presumably so test runners
    # attribute the generated tests to this file).
    generated.__module__ = __name__
    return generated
63
64
# Base test classes; each gets a dynamic-shapes variant generated and
# registered in this module by make_dynamic_cls.
tests = [
    test_ctx_manager.CtxManagerTests,
    test_functions.FunctionTests,
    test_misc.MiscTests,
    test_repros.ReproTests,
    test_modules.NNModuleTests,
    test_export.ExportTests,
    test_subgraphs.SubGraphTests,
    test_higher_order_ops.HigherOrderOpTests,
    test_higher_order_ops.FuncTorchHigherOrderOpTests,
    test_aot_autograd.AotAutogradFallbackTests,
    test_sdpa.TestSDPA,
]
for _base_cls in tests:
    make_dynamic_cls(_base_cls)
# Unbind the loop variable. Otherwise the last (static) base class would remain
# a module global, where test discovery that scans module globals for TestCase
# subclasses would presumably collect and run it here as well.
del _base_cls
80del test
81
82if TEST_Z3:
83    if not config.inline_inbuilt_nn_modules:
84        # TODO model is somehow not being freed when z3 is available
85        unittest.expectedFailure(
86            DynamicShapesMiscTests.test_parameter_free_dynamic_shapes  # noqa: F821
87        )
88
89unittest.expectedFailure(
90    # Test is only valid without dynamic shapes
91    DynamicShapesReproTests.test_many_views_with_mutation_dynamic_shapes  # noqa: F821
92)
93
94# Test takes too long ~700s as of 414a1fd29f04d06e41b7f895368dd1f83a4be29d
95DynamicShapesExportTests.test_retracibility_dynamic_shapes = slowTest(  # noqa: F821
96    DynamicShapesExportTests.test_retracibility_dynamic_shapes  # noqa: F821
97)
98# Also take more than 30m as of 15cc9f2e7e7b2b175f24755925dc38d4d430905d
99DynamicShapesExportTests.test_retracibility_dict_container_inp_out_dynamic_shapes = slowTest(  # noqa: F821
100    DynamicShapesExportTests.test_retracibility_dict_container_inp_out_dynamic_shapes  # noqa: F821
101)
102DynamicShapesExportTests.test_retracibility_nested_list_out_dynamic_shapes = slowTest(  # noqa: F821
103    DynamicShapesExportTests.test_retracibility_nested_list_out_dynamic_shapes  # noqa: F821
104)
105
106if __name__ == "__main__":
107    from torch._dynamo.test_case import run_tests
108
109    if not TEST_Z3:
110        warnings.warn(
111            "translation validation is off. "
112            "Testing with translation validation requires Z3."
113        )
114
115    run_tests()
116