# Owner(s): ["oncall: export"] import torch from torch._dispatch.python import enable_python_dispatcher from torch._subclasses.schema_check_mode import SchemaCheckMode from torch.fx.operator_schemas import normalize_function from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, ops, ) from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_utils import TestCase from torch.utils._pytree import tree_map # Simplified naming for C++ classes SchemaArgument = torch._C._SchemaArgument SchemaArgType = torch._C._SchemaArgType SchemaInfo = torch._C._SchemaInfo test_classes = {} class PreDispatchSchemaCheckMode(SchemaCheckMode): """ Dispatch mode built on top of SchemaCheckMode that checks for incorrect op schemas for PreDispatch IR. This is meant to run ops in eager mode on concrete inputs, to see if they incorrectly claim to be functional (aliasing or mutating). If an op is claimed to be functional and either is detected, an error is raised. Errors will be silenced if the schema admits aliasing or mutation - the op may later decompose and become functional. 
""" def __init__(self) -> None: self._dispatch_key = torch._C.DispatchKey.PreDispatch super().__init__() def _may_alias_or_mutate(self, func, types, args, kwargs): def unwrap(e): if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor: try: return e.elem except AttributeError as t: return e return e # get arguments, outputs schema_info = SchemaInfo(func._schema) pre_arguments = normalize_function( func, args, kwargs, normalize_to_only_use_kwargs=True ).kwargs schema_info.add_argument_values(pre_arguments) out = func(*args, **kwargs) tuple_out = out if isinstance(out, tuple) else (out,) tuple_out = tree_map(unwrap, tuple_out) # check schema for i in range(len(func._schema.arguments)): for j in range(len(tuple_out)): if schema_info.may_contain_alias( SchemaArgument(SchemaArgType.output, j), SchemaArgument(SchemaArgType.input, i), ): return True if schema_info.is_mutable( SchemaArgument(SchemaArgType.input, i), ): return True return False # creating this just so we have access to the offending op def __torch_dispatch__(self, func, types, args=(), kwargs=None): try: return super().__torch_dispatch__(func, types, args=args, kwargs=kwargs) except RuntimeError as e: # check if schema claims to be either aliasing or mutating alias_or_mutate = self._may_alias_or_mutate(func, types, args, kwargs) if ( not alias_or_mutate ): # if schema is aliasing or mutating, will decompose further msg = e.args[0] e.args = ( f"""SchemaCheckMode failed with the following error on op <{func}>, meaning this op contains aliasing or mutations, despite claiming to be functional:\n\n""" + msg, ) raise e class TestOpInfo(TestCase): @ops(op_db, allowed_dtypes=(torch.float, torch.int)) def test_schema_check_op(self, device, dtype, op): sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) inputs = next(sample_inputs_itr) args = [inputs.input] + list(inputs.args) kwargs = inputs.kwargs with enable_python_dispatcher(): with PreDispatchSchemaCheckMode(): op.op(*args, **kwargs) 
# Generate the per-device variants of TestOpInfo into this module's namespace.
instantiate_device_type_tests(TestOpInfo, globals())

if __name__ == "__main__":
    # Imported here so the dynamo test runner is only pulled in when this file
    # is executed as a script.
    from torch._dynamo.test_case import run_tests

    run_tests()