# Owner(s): ["oncall: export"] import copy import unittest import torch._dynamo as torchdynamo from torch._export.db.case import ExportCase, SupportLevel from torch._export.db.examples import ( filter_examples_by_support_level, get_rewrite_cases, ) from torch.export import export from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, IS_WINDOWS, parametrize, run_tests, TestCase, ) @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test") @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support") class ExampleTests(TestCase): # TODO Maybe we should make this tests actually show up in a file? @parametrize( "name,case", filter_examples_by_support_level(SupportLevel.SUPPORTED).items(), name_fn=lambda name, case: f"case_{name}", ) def test_exportdb_supported(self, name: str, case: ExportCase) -> None: model = case.model args_export = case.example_args kwargs_export = case.example_kwargs args_model = copy.deepcopy(args_export) kwargs_model = copy.deepcopy(kwargs_export) exported_program = export( model, args_export, kwargs_export, dynamic_shapes=case.dynamic_shapes, ) exported_program.graph_module.print_readable() self.assertEqual( exported_program.module()(*args_export, **kwargs_export), model(*args_model, **kwargs_model), ) if case.extra_args is not None: args = case.extra_args args_model = copy.deepcopy(args) self.assertEqual( exported_program.module()(*args), model(*args_model), ) @parametrize( "name,case", filter_examples_by_support_level(SupportLevel.NOT_SUPPORTED_YET).items(), name_fn=lambda name, case: f"case_{name}", ) def test_exportdb_not_supported(self, name: str, case: ExportCase) -> None: model = case.model # pyre-ignore with self.assertRaises( (torchdynamo.exc.Unsupported, AssertionError, RuntimeError) ): export( model, case.example_args, case.example_kwargs, dynamic_shapes=case.dynamic_shapes, ) exportdb_not_supported_rewrite_cases = [ (name, rewrite_case) for name, case in filter_examples_by_support_level( SupportLevel.NOT_SUPPORTED_YET ).items() for rewrite_case in get_rewrite_cases(case) ] if exportdb_not_supported_rewrite_cases: @parametrize( "name,rewrite_case", exportdb_not_supported_rewrite_cases, name_fn=lambda name, case: f"case_{name}_{case.name}", ) def test_exportdb_not_supported_rewrite( self, name: str, rewrite_case: ExportCase ) -> None: # pyre-ignore export( rewrite_case.model, rewrite_case.example_args, rewrite_case.example_kwargs, dynamic_shapes=rewrite_case.dynamic_shapes, ) instantiate_parametrized_tests(ExampleTests) if __name__ == "__main__": run_tests()