# xref: /aosp_15_r20/external/pytorch/test/export/test_serdes.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["oncall: export"]
2
3import io
4
5
6try:
7    from . import test_export, testing
8except ImportError:
9    import test_export
10    import testing
11
12from torch.export import export, load, save
13
14
# Registry of every generated serdes test class, keyed by the class name.
# Entries are added by make_dynamic_cls below.
test_classes: dict[str, type] = {}
16
17
18def mocked_serder_export(*args, **kwargs):
19    ep = export(*args, **kwargs)
20    buffer = io.BytesIO()
21    save(ep, buffer)
22    buffer.seek(0)
23    loaded_ep = load(buffer)
24    return loaded_ep
25
26
def make_dynamic_cls(cls):
    """Generate and register the serdes variant of the TestCase class ``cls``.

    Delegates to ``testing.make_test_cls_with_mocked_export`` with:
    - the class-name prefix ``"SerDesExport"``,
    - the suffix ``test_export.SERDES_SUFFIX``,
    - ``mocked_serder_export`` as the mocked export function,
    - the ``_expected_failure_serdes`` xfail property.
    See ``testing`` for how the generated class is constructed.

    The generated class is recorded in ``test_classes`` and bound as a global
    of this module under its own name. Returns None.
    """
    generated = testing.make_test_cls_with_mocked_export(
        cls,
        "SerDesExport",
        test_export.SERDES_SUFFIX,
        mocked_serder_export,
        xfail_prop="_expected_failure_serdes",
    )
    name = generated.__name__

    test_classes[name] = generated
    # Required: tests are only picked up from this module's globals. Without
    # this binding the generated class's tests will not run at all.
    globals()[name] = generated
    # Report this module (not the helper's) as the class's home module.
    # Presumably this is for test naming/discovery; confirm if changing.
    generated.__module__ = __name__
42
43
# Original (un-mocked) TestCase classes from test_export; a serdes variant of
# each one is generated and registered by make_dynamic_cls.
tests = [
    test_export.TestDynamismExpression,
    test_export.TestExport,
]

for test in tests:
    make_dynamic_cls(test)

# After the loop, ``test`` still refers to the last *original* TestCase class.
# If it stayed bound in module globals, the test loader would collect it and
# run those tests again without the serdes round trip, so unbind it.
del test
51
if __name__ == "__main__":
    # The runner is imported only when this file is executed directly, so
    # importing the module (e.g. for test collection) does not load it.
    from torch._dynamo.test_case import run_tests as _run_tests

    _run_tests()
56