# Owner(s): ["module: onnx"]

"""Test consistency between the output values of torch.onnx exported operators
and torch operators given the same inputs.

Usage:

    pytest test/onnx/test_op_consistency.py

    To run tests on a specific operator (e.g. torch.ceil):

    pytest test/onnx/test_op_consistency.py -k ceil
    pytest test/onnx/test_op_consistency.py -k nn_functional_scaled_dot_product_attention

    Read more on Running and writing tests:
        https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests

Note:

    When new ops are supported, please scroll down to modify the
    EXPECTED_SKIPS_OR_FAILS and TESTED_OPS lists. See "Modify this section".
"""

from __future__ import annotations

import copy
from typing import Optional, Tuple

import onnx_test_common
import parameterized

# For readability, these two are allowed to be imported as functions
from onnx_test_common import skip, xfail

import torch
from torch.testing._internal import (
    common_device_type,
    common_methods_invocations,
    common_utils,
)

OPS_DB = copy.deepcopy(common_methods_invocations.op_db)

# Modify this section ##########################################################
# NOTE: Modify this section as more ops are supported. The list should be sorted
# alphabetically.
#
# For example, to add a test for torch.ceil:
# 1. Add "ceil" to TESTED_OPS then run pytest.
# 2. If the test fails, fix the error or add a new entry to EXPECTED_SKIPS_OR_FAILS
#    (the entry format is sketched just above the list below).

# TODO: Directly modify DecorateInfo in each OpInfo in op_db when all ops are enabled.
# Ops to be tested for numerical consistency between onnx and pytorch
# TODO: https://github.com/pytorch/pytorch/issues/102211
TESTED_OPS: frozenset[str] = frozenset(
    [
        "atan",
        "atan2",
        # "atleast_1d",  # How to support list input?
        # "atleast_2d",
        # "atleast_3d",
        "broadcast_to",
        "ceil",
        "expand",
        "flatten",
        "hstack",
        "logical_not",
        # "logit",
        "nn.functional.scaled_dot_product_attention",
        "repeat",
        "round",
        # "scatter_add",
        # "scatter_reduce",
        "sqrt",
        "stft",
        "t",
        "tile",
        "unflatten",
        "vstack",
    ]
)

# fmt: off
# Turn off black formatting to keep the list compact

# Expected failures for onnx export.
# The list should be sorted alphabetically by op name.
# Q: When should I use fixme vs skip vs xfail?
# A: Prefer xfail over skip when possible.
#     - If a test now xpasses because previous errors have been fixed, remove the
#       corresponding xfail.
#     - If a test is not failing consistently, use skip.
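# Illustrative sketch only, not an active entry: a new xfail for a hypothetical
# op "foo" that ONNX Runtime cannot run in float16 would look roughly like the
# entries below. The op name and dtype here are made up for illustration.
#
#     xfail(
#         "foo",
#         dtypes=(torch.float16,),
#         reason=onnx_test_common.reason_onnx_runtime_does_not_support("Foo", "float16"),
#     ),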
EXPECTED_SKIPS_OR_FAILS: Tuple[onnx_test_common.DecorateMeta, ...] = (
    skip(
        "atan", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Atan"),
    ),
    xfail(
        "atan", dtypes=[torch.float64],
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("Atan", ["f64"]),
    ),
    skip(
        "atan2", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Atan"),
    ),
    xfail(
        "atan2", dtypes=[torch.float64],
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("Atan", ["f64"]),
    ),
    xfail(
        "ceil", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Ceil"),
    ),
    skip("hstack", opsets=[onnx_test_common.opsets_before(11)],
         reason=onnx_test_common.reason_onnx_does_not_support("ConcatFromSequence")),
    xfail(
        "logit",
        dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
        reason=onnx_test_common.reason_onnx_does_not_support("Log", "bool, int"),
    ),
    skip("nn.functional.scaled_dot_product_attention", opsets=[onnx_test_common.opsets_before(14)],
         reason="Need Trilu."),
    skip("nn.functional.scaled_dot_product_attention",
         reason="fixme: ORT crashes on Windows, segfaults randomly on Linux"),
    xfail("round", opsets=[onnx_test_common.opsets_before(11)],
          reason=onnx_test_common.reason_onnx_does_not_support("Round")),
    xfail("round", variant_name="decimals_0", opsets=[onnx_test_common.opsets_before(11)],
          reason=onnx_test_common.reason_onnx_does_not_support("Round")),
    xfail("round", variant_name="decimals_3", opsets=[onnx_test_common.opsets_before(11)],
          reason=onnx_test_common.reason_onnx_does_not_support("Round")),
    xfail("round", variant_name="decimals_neg_3", opsets=[onnx_test_common.opsets_before(11)],
          reason=onnx_test_common.reason_onnx_does_not_support("Round")),
    skip("scatter_reduce", variant_name="amin", opsets=[onnx_test_common.opsets_before(16)],
         reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
    skip("scatter_reduce", variant_name="amax", opsets=[onnx_test_common.opsets_before(16)],
         reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
    skip("scatter_reduce", variant_name="prod", opsets=[onnx_test_common.opsets_before(16)],
         reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
    xfail("scatter_reduce", variant_name="mean",
          reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction=mean")),
    skip("scatter_reduce", variant_name="sum", opsets=[onnx_test_common.opsets_before(16)],
         reason=onnx_test_common.reason_onnx_does_not_support("ScatterElements with reduction")),
    xfail(
        "scatter_reduce",
        variant_name="sum",
        dtypes=(torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="prod",
        dtypes=(torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=prod", "float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="amin",
        dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amin", "float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="amax",
        dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amax", "float16"),
    ),
    xfail(
        "scatter_reduce",
        variant_name="mean",
        reason="ONNX doesn't support reduce='mean' option",
    ),
    skip("sqrt", dtypes=onnx_test_common.BOOL_TYPES,
         reason=onnx_test_common.reason_onnx_does_not_support("Sqrt")),
    skip("stft", opsets=[onnx_test_common.opsets_before(17)],
         reason=onnx_test_common.reason_onnx_does_not_support("STFT")),
    xfail("stft",
          reason=onnx_test_common.reason_onnx_runtime_does_not_support("STFT", "Regression on ORT=1.15: 4 percent difference")),
    skip("tile", opsets=[onnx_test_common.opsets_before(13)],
         reason=onnx_test_common.reason_onnx_does_not_support("Tile")),
    xfail("unflatten", opsets=[onnx_test_common.opsets_before(13)],
          reason="Helper function is needed to support legacy ops."),
    skip("vstack", opsets=[onnx_test_common.opsets_before(11)],
         reason=onnx_test_common.reason_onnx_does_not_support("ConcatFromSequence")),
)
# fmt: on

SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
    skip(
        "nn.functional.scaled_dot_product_attention",
        matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0,
        reason="dropout is random so the results do not match",
    ),
    skip(
        "repeat",
        reason="Empty repeats value leads to an invalid graph",
        matcher=lambda sample: not sample.args[0],
    ),
    skip(
        "scatter_reduce",
        # ONNX has no include_self parameter; its behavior matches include_self=True.
        matcher=lambda sample: sample.kwargs.get("include_self") is False,
        reason="ONNX doesn't support include_self=False",
    ),
    skip(
        "stft",
        reason="ONNX STFT does not support complex results",
        matcher=lambda sample: sample.kwargs.get("return_complex") is True,
    ),
    skip(
        "tile",
        matcher=lambda sample: any(dim == 0 for dim in sample.input.shape)
        or not sample.input.shape,
        reason="Logic not implemented for size 0 inputs in op.Reshape",
    ),
    skip(
        "unflatten",
        reason="Logic not implemented for size 0 inputs in op.Reshape",
        matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
    ),
)
# END OF SECTION TO MODIFY #####################################################


OP_WITH_SKIPPED_XFAIL_SUBTESTS = frozenset(meta.op_name for meta in SKIP_XFAIL_SUBTESTS)
ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB)
# Assert all ops in TESTED_OPS are in OPS_DB
assert TESTED_OPS.issubset(ALL_OPS_IN_DB), f"{TESTED_OPS - ALL_OPS_IN_DB} not in OPS_DB"


class SingleOpModel(torch.nn.Module):
    """Test model to wrap around a single op for export."""

    def __init__(self, op, kwargs):
        super().__init__()
        self.operator = op
        self.kwargs = kwargs

    def forward(self, *args):
        return self.operator(*args, **self.kwargs)


def _should_skip_xfail_test_sample(
    op_name: str, sample
) -> Tuple[Optional[str], Optional[str]]:
    """Returns a reason if a test sample should be skipped."""
    if op_name not in OP_WITH_SKIPPED_XFAIL_SUBTESTS:
        return None, None
    for decorator_meta in SKIP_XFAIL_SUBTESTS:
        # Linear search on SKIP_XFAIL_SUBTESTS. That's fine because the list is small.
        if decorator_meta.op_name == op_name:
            assert decorator_meta.matcher is not None, "Matcher must be defined"
            if decorator_meta.matcher(sample):
                return decorator_meta.test_behavior, decorator_meta.reason
    return None, None


def _get_test_class_name(cls, num, params_dict) -> str:
    del cls  # unused
    del num  # unused
    return params_dict["name"]


@parameterized.parameterized_class(
    [
        {
            "name": f"TestOnnxModelOutputConsistency_opset{opset}",
            "opset_version": opset,
        }
        for opset in onnx_test_common.TESTED_OPSETS
    ],
    class_name_func=_get_test_class_name,
)
class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
    """Test output consistency between exported ONNX models and PyTorch eager mode.

    This is a parameterized test suite.
""" opset_version = -1 @common_device_type.ops( [op for op in OPS_DB if op.name in TESTED_OPS], allowed_dtypes=onnx_test_common.INT_TYPES + onnx_test_common.FLOAT_TYPES + onnx_test_common.BOOL_TYPES, ) def test_output_match(self, device: str, dtype: torch.dtype, op): """Test the ONNX exporter.""" # device is provided by instantiate_device_type_tests, but we only want to run in cpu. assert device == "cpu" samples = op.sample_inputs( device, dtype, requires_grad=False, ) for i, cpu_sample in enumerate(samples): inputs = (cpu_sample.input, *cpu_sample.args) # Provide the repr to subtest because tensors are not serializable in parallel test runs with self.subTest( opset=self.opset_version, sample_num=i, inputs=repr(inputs), kwargs=repr(cpu_sample.kwargs), ): test_behavior, reason = _should_skip_xfail_test_sample( op.name, cpu_sample ) with onnx_test_common.normal_xfail_skip_test_behaviors( test_behavior, reason ): model = SingleOpModel(op, cpu_sample.kwargs) model.eval() if dtype == torch.float32: # Relax atol and rtol for float32 based on empirical results # The current most relaxed values are for aten::stft rtol = 1e-5 atol = 2e-5 elif dtype == torch.float64: # The current most relaxed values are for aten::stft rtol = 1e-5 atol = 2e-5 else: rtol = None atol = None # Run the test self.run_test(model, inputs, rtol=rtol, atol=atol) for opset in onnx_test_common.TESTED_OPSETS: # The name needs to match the parameterized_class name. test_class_name = f"TestOnnxModelOutputConsistency_opset{opset}" onnx_test_common.add_decorate_info( OPS_DB, test_class_name, "test_output_match", opset=opset, skip_or_xfails=EXPECTED_SKIPS_OR_FAILS, ) common_device_type.instantiate_device_type_tests( globals()[test_class_name], globals(), only_for="cpu" ) if __name__ == "__main__": common_utils.run_tests()