# Owner(s): ["oncall: export"]

try:
    from . import test_export, testing
except ImportError:
    import test_export
    import testing

from torch.export import export


# Registry of every generated retraceability test class, keyed by class name.
test_classes = {}


def mocked_retraceability_export(*args, **kwargs):
    """Export, then export the resulting module a second time, and return that.

    This stands in for ``torch.export.export`` inside the generated test
    classes. The first positional argument is the object to export. All other
    positional and keyword arguments go to both ``export`` calls.

    A dict-valued ``dynamic_shapes`` is turned into a tuple of its values
    before the second call. The second call therefore gets the specs
    positionally, not keyed by name. The first call still sees the original
    dict.

    NOTE(review): the tuple conversion assumes the dict's insertion order
    matches the order of the positional inputs. Confirm this holds for
    every caller.
    """
    first_pass = export(*args, **kwargs)

    # A missing key yields None, which is not a dict, so only an explicitly
    # passed dict is flattened.
    if isinstance(kwargs.get("dynamic_shapes"), dict):
        kwargs["dynamic_shapes"] = tuple(kwargs["dynamic_shapes"].values())

    remaining_args = args[1:]
    return export(first_pass.module(), *remaining_args, **kwargs)


def make_dynamic_cls(cls):
    """Build and register a retraceability variant of the test class ``cls``.

    The variant comes from ``testing.make_test_cls_with_mocked_export``, with
    ``mocked_retraceability_export`` swapped in for ``export``. Tests flagged
    with ``_expected_failure_retrace`` are marked as expected failures. The
    new class is recorded in ``test_classes`` and bound as a global of this
    module.

    Returns:
        The generated test class.
    """
    generated = testing.make_test_cls_with_mocked_export(
        cls,
        "RetraceExport",
        test_export.RETRACEABILITY_SUFFIX,
        mocked_retraceability_export,
        xfail_prop="_expected_failure_retrace",
    )

    class_name = generated.__name__
    test_classes[class_name] = generated
    # The test runner finds this class through this module's globals.
    # Removing the assignment stops these tests from running.
    globals()[class_name] = generated
    generated.__module__ = __name__
    return generated


tests = [
    test_export.TestDynamismExpression,
    test_export.TestExport,
]
for base_cls in tests:
    make_dynamic_cls(base_cls)
# Delete the loop variable so the last base test class is not left behind as
# a module global (presumably so the runner does not also collect it from
# this module).
del base_cls

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()