xref: /aosp_15_r20/external/pytorch/test/export/opinfo_schema.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: export"]
2
3import torch
4from torch._dispatch.python import enable_python_dispatcher
5from torch._subclasses.schema_check_mode import SchemaCheckMode
6from torch.fx.operator_schemas import normalize_function
7from torch.testing._internal.common_device_type import (
8    instantiate_device_type_tests,
9    ops,
10)
11from torch.testing._internal.common_methods_invocations import op_db
12from torch.testing._internal.common_utils import TestCase
13from torch.utils._pytree import tree_map
14
15
# Simplified naming for C++ classes
# (torch._C schema-introspection types used by PreDispatchSchemaCheckMode below)
SchemaArgument = torch._C._SchemaArgument
SchemaArgType = torch._C._SchemaArgType
SchemaInfo = torch._C._SchemaInfo

# NOTE(review): not referenced anywhere in this file — looks like a leftover
# from the instantiate_device_type_tests pattern; confirm before removing.
test_classes = {}
22
23
24class PreDispatchSchemaCheckMode(SchemaCheckMode):
25    """
26    Dispatch mode built on top of SchemaCheckMode that checks for incorrect op schemas
27    for PreDispatch IR. This is meant to run ops in eager mode on concrete inputs, to
28    see if they incorrectly claim to be functional (aliasing or mutating).
29
30    If an op is claimed to be functional and either is detected, an error is raised.
31    Errors will be silenced if the schema admits aliasing or mutation - the op may
32    later decompose and become functional.
33    """
34
35    def __init__(self) -> None:
36        self._dispatch_key = torch._C.DispatchKey.PreDispatch
37        super().__init__()
38
39    def _may_alias_or_mutate(self, func, types, args, kwargs):
40        def unwrap(e):
41            if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor:
42                try:
43                    return e.elem
44                except AttributeError as t:
45                    return e
46            return e
47
48        # get arguments, outputs
49        schema_info = SchemaInfo(func._schema)
50        pre_arguments = normalize_function(
51            func, args, kwargs, normalize_to_only_use_kwargs=True
52        ).kwargs
53        schema_info.add_argument_values(pre_arguments)
54        out = func(*args, **kwargs)
55        tuple_out = out if isinstance(out, tuple) else (out,)
56        tuple_out = tree_map(unwrap, tuple_out)
57
58        # check schema
59        for i in range(len(func._schema.arguments)):
60            for j in range(len(tuple_out)):
61                if schema_info.may_contain_alias(
62                    SchemaArgument(SchemaArgType.output, j),
63                    SchemaArgument(SchemaArgType.input, i),
64                ):
65                    return True
66            if schema_info.is_mutable(
67                SchemaArgument(SchemaArgType.input, i),
68            ):
69                return True
70
71        return False
72
73    # creating this just so we have access to the offending op
74    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
75        try:
76            return super().__torch_dispatch__(func, types, args=args, kwargs=kwargs)
77        except RuntimeError as e:
78            # check if schema claims to be either aliasing or mutating
79            alias_or_mutate = self._may_alias_or_mutate(func, types, args, kwargs)
80            if (
81                not alias_or_mutate
82            ):  # if schema is aliasing or mutating, will decompose further
83                msg = e.args[0]
84                e.args = (
85                    f"""SchemaCheckMode failed with the following error on op <{func}>, meaning
86    this op contains aliasing or mutations, despite claiming to be functional:\n\n"""
87                    + msg,
88                )
89                raise e
90
91
class TestOpInfo(TestCase):
    @ops(op_db, allowed_dtypes=(torch.float, torch.int))
    def test_schema_check_op(self, device, dtype, op):
        """Run the first OpInfo sample for ``op`` in eager mode under
        PreDispatchSchemaCheckMode, which raises if the op aliases or
        mutates despite claiming a functional schema."""
        sample = next(op.sample_inputs(device, dtype, requires_grad=False))
        positional = [sample.input, *sample.args]
        with enable_python_dispatcher(), PreDispatchSchemaCheckMode():
            op.op(*positional, **sample.kwargs)
102
103
# Register device-parameterized variants of TestOpInfo into this module's
# namespace so the test runner discovers them per device type.
instantiate_device_type_tests(TestOpInfo, globals())
105
if __name__ == "__main__":
    # Imported here so the dynamo test runner is only required when this
    # file is executed directly.
    from torch._dynamo.test_case import run_tests

    run_tests()
110