xref: /aosp_15_r20/external/pytorch/test/export/test_export_nonstrict.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: export"]
2
3try:
4    from . import test_export, testing
5except ImportError:
6    import test_export
7    import testing
8
9from torch.export import export
10
11
# Registry of the generated non-strict test classes, keyed by the generated
# class's ``__name__``; populated by ``make_dynamic_cls`` below.
test_classes: dict = {}
13
14
15def mocked_non_strict_export(*args, **kwargs):
16    # If user already specified strict, don't make it non-strict
17    if "strict" in kwargs:
18        return export(*args, **kwargs)
19    return export(*args, **kwargs, strict=False)
20
21
def make_dynamic_cls(cls):
    """Create and register a non-strict-export variant of test class ``cls``.

    The variant is built by ``testing.make_test_cls_with_mocked_export`` with
    the prefix ``"NonStrictExport"``, ``test_export.NON_STRICT_SUFFIX``, and
    ``mocked_non_strict_export`` as the mocked export function. Per the
    helper's name, that function presumably stands in for ``export`` inside
    the copied tests. The ``_expected_failure_non_strict`` property is passed
    as ``xfail_prop``.

    Returns:
        The generated class. It is also recorded in ``test_classes``, bound
        as a global of this module, and has ``__module__`` set to this
        module's name.
    """
    generated = testing.make_test_cls_with_mocked_export(
        cls,
        "NonStrictExport",
        test_export.NON_STRICT_SUFFIX,
        mocked_non_strict_export,
        xfail_prop="_expected_failure_non_strict",
    )
    generated_name = generated.__name__

    test_classes[generated_name] = generated
    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING: the runner discovers
    # TestCase classes by scanning this module's global namespace.
    globals()[generated_name] = generated
    # Presumably so reports attribute the class to this file rather than to
    # the module where the helper created it.
    generated.__module__ = __name__
    return generated
38
39
# Strict-export test classes that get a non-strict counterpart in this module.
tests = [
    test_export.TestDynamismExpression,
    test_export.TestExport,
]

for _base_test_cls in tests:
    make_dynamic_cls(_base_test_cls)
# Drop the leaked loop variable. Left bound, it would expose an original
# (strict) TestCase class as a module global, and unittest-style discovery
# would presumably collect and rerun it from this module.
del _base_test_cls
47
if __name__ == "__main__":
    # Run the generated test classes with the torch._dynamo test runner.
    from torch._dynamo.test_case import run_tests as _run_dynamo_tests

    _run_dynamo_tests()
52