1# Owner(s): ["oncall: export"] 2 3import torch 4from torch._dispatch.python import enable_python_dispatcher 5from torch._subclasses.schema_check_mode import SchemaCheckMode 6from torch.fx.operator_schemas import normalize_function 7from torch.testing._internal.common_device_type import ( 8 instantiate_device_type_tests, 9 ops, 10) 11from torch.testing._internal.common_methods_invocations import op_db 12from torch.testing._internal.common_utils import TestCase 13from torch.utils._pytree import tree_map 14 15 16# Simplified naming for C++ classes 17SchemaArgument = torch._C._SchemaArgument 18SchemaArgType = torch._C._SchemaArgType 19SchemaInfo = torch._C._SchemaInfo 20 21test_classes = {} 22 23 24class PreDispatchSchemaCheckMode(SchemaCheckMode): 25 """ 26 Dispatch mode built on top of SchemaCheckMode that checks for incorrect op schemas 27 for PreDispatch IR. This is meant to run ops in eager mode on concrete inputs, to 28 see if they incorrectly claim to be functional (aliasing or mutating). 29 30 If an op is claimed to be functional and either is detected, an error is raised. 31 Errors will be silenced if the schema admits aliasing or mutation - the op may 32 later decompose and become functional. 33 """ 34 35 def __init__(self) -> None: 36 self._dispatch_key = torch._C.DispatchKey.PreDispatch 37 super().__init__() 38 39 def _may_alias_or_mutate(self, func, types, args, kwargs): 40 def unwrap(e): 41 if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor: 42 try: 43 return e.elem 44 except AttributeError as t: 45 return e 46 return e 47 48 # get arguments, outputs 49 schema_info = SchemaInfo(func._schema) 50 pre_arguments = normalize_function( 51 func, args, kwargs, normalize_to_only_use_kwargs=True 52 ).kwargs 53 schema_info.add_argument_values(pre_arguments) 54 out = func(*args, **kwargs) 55 tuple_out = out if isinstance(out, tuple) else (out,) 56 tuple_out = tree_map(unwrap, tuple_out) 57 58 # check schema 59 for i in range(len(func._schema.arguments)): 60 for j in range(len(tuple_out)): 61 if schema_info.may_contain_alias( 62 SchemaArgument(SchemaArgType.output, j), 63 SchemaArgument(SchemaArgType.input, i), 64 ): 65 return True 66 if schema_info.is_mutable( 67 SchemaArgument(SchemaArgType.input, i), 68 ): 69 return True 70 71 return False 72 73 # creating this just so we have access to the offending op 74 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 75 try: 76 return super().__torch_dispatch__(func, types, args=args, kwargs=kwargs) 77 except RuntimeError as e: 78 # check if schema claims to be either aliasing or mutating 79 alias_or_mutate = self._may_alias_or_mutate(func, types, args, kwargs) 80 if ( 81 not alias_or_mutate 82 ): # if schema is aliasing or mutating, will decompose further 83 msg = e.args[0] 84 e.args = ( 85 f"""SchemaCheckMode failed with the following error on op <{func}>, meaning 86 this op contains aliasing or mutations, despite claiming to be functional:\n\n""" 87 + msg, 88 ) 89 raise e 90 91 92class TestOpInfo(TestCase): 93 @ops(op_db, allowed_dtypes=(torch.float, torch.int)) 94 def test_schema_check_op(self, device, dtype, op): 95 sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) 96 inputs = next(sample_inputs_itr) 97 args = [inputs.input] + list(inputs.args) 98 kwargs = inputs.kwargs 99 with enable_python_dispatcher(): 100 with PreDispatchSchemaCheckMode(): 101 op.op(*args, **kwargs) 102 103 104instantiate_device_type_tests(TestOpInfo, globals()) 105 106if __name__ == "__main__": 107 from torch._dynamo.test_case import run_tests 108 109 run_tests() 110