# Owner(s): ["oncall: export"]
# flake8: noqa
"""Generic export tests for higher-order operators (HOPs).

Every HOP listed in ``hop_db`` (minus the explicit allowlist) is run through
several export flows -- AOT export, pre-dispatch export, retracing via
``run_decompositions`` and a save/load round trip -- and the exported program's
outputs are compared against eager execution.
"""
import copy
import io
import unittest

import torch
import torch._dynamo as torchdynamo
import torch.utils._pytree as pytree
from torch._dynamo.test_case import TestCase
from torch.export import export, load, save
from torch.export._trace import _export
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    ops,
)
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    run_tests,
    TestCase as TorchTestCase,
)
from torch.testing._internal.hop_db import (
    hop_db,
    hop_that_doesnt_have_opinfo_test_allowlist,
)


# OpInfos that the parametrized tests below run over: everything in hop_db
# whose name is not explicitly allowlisted.
hop_tests = [
    op_info
    for op_info in hop_db
    if op_info.name not in hop_that_doesnt_have_opinfo_test_allowlist
]


class TestHOPGeneric(TestCase):
    """Sanity checks on the HOP OpInfo database itself."""

    def test_all_hops_have_op_info(self):
        """Every registered HOP must have an OpInfo or be allowlisted."""
        from torch._ops import _higher_order_ops

        hops_that_have_op_info = {k.name for k in hop_db}
        missing_ops = [
            op
            for op in _higher_order_ops.keys()
            if op not in hops_that_have_op_info
            and op not in hop_that_doesnt_have_opinfo_test_allowlist
        ]
        self.assertFalse(missing_ops, f"Missing op info for {missing_ops}")


@unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case")
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
class TestHOP(TestCase):
    """Runs each HOP sample through the export flows and compares with eager."""

    def _compare(self, eager_model, export, args, kwargs):
        """Assert that ``export.module()`` matches ``eager_model`` on the inputs.

        Args:
            eager_model: The original callable module.
            export: The exported program; only ``export.module()`` is used.
            args: Positional inputs (deep-copied before each call).
            kwargs: Keyword inputs (deep-copied before each call).
        """
        # Each side gets its own deep copy so that an in-place mutation by one
        # run cannot influence the inputs seen by the other.
        eager_args = copy.deepcopy(args)
        eager_kwargs = copy.deepcopy(kwargs)
        export_args = copy.deepcopy(args)
        export_kwargs = copy.deepcopy(kwargs)

        flat_orig_outputs = pytree.tree_leaves(eager_model(*eager_args, **eager_kwargs))
        flat_loaded_outputs = pytree.tree_leaves(
            export.module()(*export_args, **export_kwargs)
        )

        # zip() truncates to the shorter sequence, so without this check an
        # exported program returning a different number of leaves would pass.
        self.assertEqual(len(flat_orig_outputs), len(flat_loaded_outputs))

        for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs):
            self.assertEqual(type(orig), type(loaded))
            self.assertEqual(orig, loaded)

    def _export_cases(self, device, dtype, op):
        """Yield ``(model, args, kwargs)`` for every sample input of ``op``.

        ``model`` is a fresh module whose ``forward`` forwards its positional
        arguments to ``op.op``. ``args`` is the sample's ``input`` (wrapped in
        a tuple when it is not one already) followed by the sample's ``args``.
        """

        class Foo(torch.nn.Module):
            def forward(self, *args):
                return op.op(*args)

        for sample in op.sample_inputs(device, dtype, requires_grad=True):
            leading = (
                sample.input if isinstance(sample.input, tuple) else (sample.input,)
            )
            yield Foo(), (*leading, *sample.args), sample.kwargs

    @ops(hop_tests, allowed_dtypes=(torch.float,))
    def test_aot_export(self, device, dtype, op):
        """Export with the public ``export`` API and compare with eager."""
        for model, args, kwargs in self._export_cases(device, dtype, op):
            ep = export(model, args, kwargs)
            self._compare(model, ep, args, kwargs)

    @ops(hop_tests, allowed_dtypes=(torch.float,))
    def test_pre_dispatch_export(self, device, dtype, op):
        """Export with ``pre_dispatch=True`` and compare with eager."""
        for model, args, kwargs in self._export_cases(device, dtype, op):
            ep = _export(model, args, kwargs, pre_dispatch=True)
            self._compare(model, ep, args, kwargs)

    @ops(hop_tests, allowed_dtypes=(torch.float,))
    def test_retrace_export(self, device, dtype, op):
        """Pre-dispatch export, then ``run_decompositions``, then compare."""
        for model, args, kwargs in self._export_cases(device, dtype, op):
            ep = _export(model, args, kwargs, pre_dispatch=True)
            ep = ep.run_decompositions()
            self._compare(model, ep, args, kwargs)

    @ops(hop_tests, allowed_dtypes=(torch.float,))
    def test_serialize_export(self, device, dtype, op):
        """Round-trip the decomposed program through save/load, then compare."""
        for model, args, kwargs in self._export_cases(device, dtype, op):
            ep = _export(model, args, kwargs, pre_dispatch=True)
            ep = ep.run_decompositions()

            buffer = io.BytesIO()
            save(ep, buffer)
            buffer.seek(0)
            ep = load(buffer)

            if "while_loop" in str(op):
                # while_loop's arguments are cast into a list after
                # deserialization, but while_loop expects them to still be a
                # tuple, so running the loaded program is expected to fail.
                with self.assertRaisesRegex(
                    RuntimeError, "carried_inputs must be a tuple"
                ):
                    self._compare(model, ep, args, kwargs)
            else:
                self._compare(model, ep, args, kwargs)


instantiate_device_type_tests(TestHOP, globals())

if __name__ == "__main__":
    run_tests()