xref: /aosp_15_r20/external/pytorch/test/export/test_retraceability.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: export"]
2
3try:
4    from . import test_export, testing
5except ImportError:
6    import test_export
7    import testing
8
9from torch.export import export
10
11
# Registry of the generated retrace test classes, keyed by class name.
# make_dynamic_cls adds one entry for every class it generates.
test_classes = {}
13
14
15def mocked_retraceability_export(*args, **kwargs):
16    ep = export(*args, **kwargs)
17    if "dynamic_shapes" in kwargs:
18        if isinstance(kwargs["dynamic_shapes"], dict):
19            kwargs["dynamic_shapes"] = tuple(kwargs["dynamic_shapes"].values())
20
21    ep = export(ep.module(), *(args[1:]), **kwargs)
22    return ep
23
24
def make_dynamic_cls(cls):
    """Build and register a retraceability variant of the test class ``cls``.

    The variant is produced by ``testing.make_test_cls_with_mocked_export``.
    That helper receives the ``"RetraceExport"`` prefix, the suffix
    ``test_export.RETRACEABILITY_SUFFIX``, ``mocked_retraceability_export``
    as the export replacement, and the xfail property name
    ``"_expected_failure_retrace"``. Presumably tests carrying that property
    are treated as expected failures; confirm in ``testing``.

    The generated class is recorded in ``test_classes``, published as a
    module global, and re-homed to this module via ``__module__``.

    Args:
        cls: The original test class to derive the variant from.

    Returns:
        The generated test class.
    """
    generated = testing.make_test_cls_with_mocked_export(
        cls,
        "RetraceExport",
        test_export.RETRACEABILITY_SUFFIX,
        mocked_retraceability_export,
        xfail_prop="_expected_failure_retrace",
    )
    generated_name = generated.__name__

    test_classes[generated_name] = generated
    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING: test loaders find
    # TestCase classes through this module's attributes.
    globals()[generated_name] = generated
    # Presumably makes runners report the class as belonging to this file.
    generated.__module__ = __name__
    return generated
41
42
# Base suites whose tests are re-run here through the retracing export.
tests = [
    test_export.TestDynamismExpression,
    test_export.TestExport,
]
for _base_suite in tests:
    make_dynamic_cls(_base_suite)
# The loop variable still references an original TestCase class. Drop it so
# a loader scanning module attributes does not collect that class here.
del _base_suite

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
55