xref: /aosp_15_r20/external/executorch/exir/verification/verifier.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerimport itertools
8*523fa7a6SAndroid Build Coastguard Workerimport operator
9*523fa7a6SAndroid Build Coastguard Workerimport types
10*523fa7a6SAndroid Build Coastguard Workerfrom contextlib import nullcontext
11*523fa7a6SAndroid Build Coastguard Workerfrom typing import Any, List, Optional, Tuple, Type
12*523fa7a6SAndroid Build Coastguard Worker
13*523fa7a6SAndroid Build Coastguard Workerimport torch
14*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.capture._config import EdgeCompileConfig
15*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.dialects.edge._ops import EdgeOpOverload
16*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.error import ExportError, ExportErrorType
17*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.lowered_backend_module import LoweredBackendModule
18*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.verification.arg_validator import (
19*523fa7a6SAndroid Build Coastguard Worker    EdgeOpArgValidator,
20*523fa7a6SAndroid Build Coastguard Worker    RunHigherOrderOperatorError,
21*523fa7a6SAndroid Build Coastguard Worker)
22*523fa7a6SAndroid Build Coastguard Workerfrom torch._dispatch.python import enable_python_dispatcher
23*523fa7a6SAndroid Build Coastguard Workerfrom torch._export.utils import _detect_fake_mode_from_gm
24*523fa7a6SAndroid Build Coastguard Worker
25*523fa7a6SAndroid Build Coastguard Workerfrom torch._export.verifier import SpecViolationError, Verifier
26*523fa7a6SAndroid Build Coastguard Workerfrom torch._ops import OpOverload
27*523fa7a6SAndroid Build Coastguard Workerfrom torch._subclasses import FakeTensor
28*523fa7a6SAndroid Build Coastguard Workerfrom torch.export.exported_program import ExportedProgram
29*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx import GraphModule
30*523fa7a6SAndroid Build Coastguard Worker
31*523fa7a6SAndroid Build Coastguard Worker
# Set of node-metadata key names exported by this module.
# NOTE(review): not referenced anywhere in this file; presumably consumed by
# other modules — confirm before changing its contents.
ALLOWED_META_KEYS = {"stack_trace", "spec"}
33*523fa7a6SAndroid Build Coastguard Worker
34*523fa7a6SAndroid Build Coastguard Worker
35*523fa7a6SAndroid Build Coastguard Workerdef _check_tensors_are_contiguous(gm: GraphModule) -> None:
36*523fa7a6SAndroid Build Coastguard Worker    # Tensors be of contiguous format
37*523fa7a6SAndroid Build Coastguard Worker    for name, param in itertools.chain(gm.named_parameters(), gm.named_buffers()):
38*523fa7a6SAndroid Build Coastguard Worker        if isinstance(param, torch.Tensor):
39*523fa7a6SAndroid Build Coastguard Worker            if not param.is_contiguous():
40*523fa7a6SAndroid Build Coastguard Worker                raise SpecViolationError(
41*523fa7a6SAndroid Build Coastguard Worker                    f"Tensors in Aten dialect must be contiguous, {name} is not contiguous"
42*523fa7a6SAndroid Build Coastguard Worker                )
43*523fa7a6SAndroid Build Coastguard Worker
44*523fa7a6SAndroid Build Coastguard Worker
45*523fa7a6SAndroid Build Coastguard Workerdef _check_valid_dim_order_ops(op, use_dim_order) -> None:
46*523fa7a6SAndroid Build Coastguard Worker    if use_dim_order:
47*523fa7a6SAndroid Build Coastguard Worker        if op in (torch.ops.aten._to_copy.default,):
48*523fa7a6SAndroid Build Coastguard Worker            raise SpecViolationError(f"{op} should not be used in dim_order mode")
49*523fa7a6SAndroid Build Coastguard Worker    else:  # not using dim_order
50*523fa7a6SAndroid Build Coastguard Worker        if op.namespace in ("dim_order_ops",):
51*523fa7a6SAndroid Build Coastguard Worker            raise SpecViolationError(f"{op} should not be used in non-dim_order mode")
52*523fa7a6SAndroid Build Coastguard Worker
53*523fa7a6SAndroid Build Coastguard Worker
54*523fa7a6SAndroid Build Coastguard Workerclass EXIRATenDialectVerifierBase(Verifier):
55*523fa7a6SAndroid Build Coastguard Worker    dialect = "OLD_EXIR_ATEN_DISABLED"
56*523fa7a6SAndroid Build Coastguard Worker
57*523fa7a6SAndroid Build Coastguard Worker    def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
58*523fa7a6SAndroid Build Coastguard Worker        return (
59*523fa7a6SAndroid Build Coastguard Worker            torch.fx.GraphModule,
60*523fa7a6SAndroid Build Coastguard Worker            LoweredBackendModule,
61*523fa7a6SAndroid Build Coastguard Worker            torch.Tensor,
62*523fa7a6SAndroid Build Coastguard Worker            torch.ScriptObject,
63*523fa7a6SAndroid Build Coastguard Worker        )
64*523fa7a6SAndroid Build Coastguard Worker
65*523fa7a6SAndroid Build Coastguard Worker    def allowed_op_types(self):
66*523fa7a6SAndroid Build Coastguard Worker        return super().allowed_op_types() + (torch._ops.OpOverloadPacket,)
67*523fa7a6SAndroid Build Coastguard Worker
68*523fa7a6SAndroid Build Coastguard Worker    def __call__(self, *args, **kwargs):
69*523fa7a6SAndroid Build Coastguard Worker        if hasattr(self, "_check_graph_module"):
70*523fa7a6SAndroid Build Coastguard Worker            return self._check_graph_module(*args, **kwargs)
71*523fa7a6SAndroid Build Coastguard Worker        elif hasattr(self, "check_valid"):
72*523fa7a6SAndroid Build Coastguard Worker            return self.check_valid(*args, **kwargs)
73*523fa7a6SAndroid Build Coastguard Worker        else:
74*523fa7a6SAndroid Build Coastguard Worker            raise RuntimeError("")
75*523fa7a6SAndroid Build Coastguard Worker
76*523fa7a6SAndroid Build Coastguard Worker
def EXIRATenDialectVerifier(  # noqa: C901
    edge_compile_config: Optional[EdgeCompileConfig] = None,
    class_only: bool = False,
    exception_list: Optional[List[torch._ops.OpOverload]] = None,
):
    """
    Build a verifier that enforces ATen-dialect operator rules on a graph module.

    Args:
        edge_compile_config: when given with a non-empty
            ``_core_aten_ops_exception_list``, that list is prepended to
            ``exception_list``.
        class_only: return the verifier class itself instead of an instance.
        exception_list: extra operators exempted from the core-ATen check.

    Returns:
        The ``_EXIRATenDialectVerifier`` class when ``class_only`` is true,
        otherwise a newly constructed instance of it.
    """
    # Put the config-provided exemptions in front of the caller-provided ones.
    config_exemptions = (
        edge_compile_config._core_aten_ops_exception_list
        if edge_compile_config
        else None
    )
    if config_exemptions:
        exception_list = config_exemptions + (exception_list or [])

    class _EXIRATenDialectVerifier(EXIRATenDialectVerifierBase):
        dialect = "OLD_EXIR_ATEN"

        def __init__(self) -> None:
            super().__init__()
            # Captured from the enclosing factory call (after the merge above).
            self._exception_list = exception_list or []

        def _get_exception_list(self) -> List[torch._ops.OpOverload]:
            """Return a new list of built-in exemptions followed by ``self._exception_list``."""
            return [
                torch.ops.aten.mkldnn_rnn_layer.default,
                torch.ops.aten._upsample_bilinear2d_aa.default,
                torch.ops.aten.quantize_per_tensor.default,
                torch.ops.aten.dequantize.self,
                torch.ops.aten.max.default,  # TODO(T188268054)
                torch.ops.aten.min.default,  # TODO(T188268054)
                torch.ops.aten.full_like.default,  # TODO(T183507359)
                *self._exception_list,
            ]

        def check_valid_op(self, op):
            """Raise if ``op`` is an ``aten`` overload tagged neither core nor view_copy.

            The following are accepted without further checks: targets that
            are not ``OpOverload``, overloads outside the ``aten`` namespace,
            and exempted operators.
            """
            if not isinstance(op, OpOverload):
                return
            # TODO These special ops should be removable easily.
            if op.namespace != "aten" or op in self._get_exception_list():
                return
            # NOTE(qihan): whether view_copy operators are marked as canonical is still under
            #            discussion.
            if torch.Tag.core in op.tags or torch.Tag.view_copy in op.tags:
                return
            raise SpecViolationError(
                f"Operator {op.__module__}.{op.__name__} is not Aten Canonical."
            )

    if class_only:
        return _EXIRATenDialectVerifier
    return _EXIRATenDialectVerifier()
129*523fa7a6SAndroid Build Coastguard Worker
130*523fa7a6SAndroid Build Coastguard Worker
def get_aten_verifier(config: EdgeCompileConfig):
    """Select the ATen-dialect verifier class for ``config``.

    Args:
        config: compile configuration. ``_check_ir_validity`` chooses between
            strict and permissive verification. ``_core_aten_ops_exception_list``
            supplies the exemptions for the strict verifier.

    Returns:
        A verifier class, not an instance. When IR validity checking is
        enabled this is the class built by
        ``EXIRATenDialectVerifier(class_only=True, ...)``. Otherwise it is
        ``EXIRATenDialectVerifierBase``.
    """
    if not config._check_ir_validity:
        return EXIRATenDialectVerifierBase
    return EXIRATenDialectVerifier(
        class_only=True,
        exception_list=config._core_aten_ops_exception_list,
    )
139*523fa7a6SAndroid Build Coastguard Worker
140*523fa7a6SAndroid Build Coastguard Worker
def _get_inputs(graph_module: GraphModule) -> List[Optional[FakeTensor]]:
    """Collect one example input per placeholder of ``graph_module``, in graph order.

    Each placeholder contributes its ``meta["val"]`` when present. A placeholder
    without ``"val"`` contributes ``None`` if nothing uses it.

    Raises:
        ExportError: (``VIOLATION_OF_SPEC``) for a used placeholder lacking ``"val"``.
    """

    def _input_for(placeholder: torch.fx.Node) -> Optional[FakeTensor]:
        meta = placeholder.meta
        if "val" in meta:
            return meta["val"]
        if not placeholder.users:
            # Unused input: no example value is needed.
            return None
        # TODO(ycao): `val` should always exist after we enable shape environment
        # serialization and deserialization.
        raise ExportError(
            ExportErrorType.VIOLATION_OF_SPEC,
            f"Cannot construct an input for graph module: {graph_module}.",
        )

    placeholders = (
        node for node in graph_module.graph.nodes if node.op == "placeholder"
    )
    return [_input_for(placeholder) for placeholder in placeholders]
161*523fa7a6SAndroid Build Coastguard Worker
162*523fa7a6SAndroid Build Coastguard Worker
def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:
    """Check that operators in ``gm`` receive tensor inputs with allowed dtypes.

    ``EdgeOpArgValidator(gm).run`` is invoked on the inputs returned by
    ``_get_inputs(gm)``. It runs under the Python dispatcher and under the fake
    mode detected from ``gm``, if any (otherwise a no-op context).

    Raises:
        SpecViolationError: if the validator recorded any ``violating_ops``.
    """
    arg_validator = EdgeOpArgValidator(gm)
    example_inputs = _get_inputs(gm)
    fake_ctx = _detect_fake_mode_from_gm(gm) or nullcontext()
    try:
        with enable_python_dispatcher(), fake_ctx:
            arg_validator.run(*example_inputs)
    except RunHigherOrderOperatorError:
        # NB: higher order operators in the graph are ignored. Composing a graph
        # lowered to a delegate with another graph module, retracing it with edge
        # ops and the validator enabled (_check_ir_validity=True), hits this
        # error. It is still a valid Edge dialect by definition. This is not
        # ideal, because it skips possible invalidity later in the graph.
        return
    else:
        violations = arg_validator.violating_ops
        if violations:
            raise SpecViolationError(
                f"These operators are taking Tensor inputs with mismatched dtypes: {violations}"
            )
183*523fa7a6SAndroid Build Coastguard Worker
184*523fa7a6SAndroid Build Coastguard Worker
def EXIREdgeDialectVerifier(  # noqa: C901
    edge_compile_config: Optional[EdgeCompileConfig] = None,
    class_only: bool = False,
    exception_list: Optional[List[torch._ops.OpOverload]] = None,
):
    """
    Returns a verifier (class or instance) that runs Edge dialect checks.

    Args:
        edge_compile_config: compile options read when the verifier is
            constructed. ``EdgeCompileConfig()`` is used when this is falsy.
            A non-empty ``_core_aten_ops_exception_list`` on it is prepended to
            ``exception_list``.
        class_only: if true, return the ``_EXIREdgeDialectVerifier`` class
            itself; otherwise return a freshly constructed instance.
        exception_list: additional operators exempted from the operator checks.
    """
    # merge the exception list from edge_compile_config and exception_list
    if edge_compile_config and edge_compile_config._core_aten_ops_exception_list:
        exception_list = edge_compile_config._core_aten_ops_exception_list + (
            exception_list or []
        )

    class _EXIREdgeDialectVerifier(Verifier):
        dialect = "EDGE"

        def __init__(self) -> None:
            _edge_compile_config = edge_compile_config or EdgeCompileConfig()

            # ``enable`` gates __call__, check_valid_edge_op and check_additional.
            self.enable = _edge_compile_config._check_ir_validity
            self.check_edge_ops = _edge_compile_config._use_edge_ops
            self.use_dim_order = not _edge_compile_config._skip_dim_order

            # ATen-level operator checks are delegated to a nested ATen verifier
            # that shares the merged exception list.
            self.aten_op_verifier = EXIRATenDialectVerifier(
                exception_list=exception_list
            )
            self.check_valid_aten_op = self.aten_op_verifier.check_valid_op

            # This instance attribute shadows the class-level ``check_valid_op``
            # defined below. Lookups through ``self`` resolve to one of these two.
            if self.check_edge_ops:
                self.check_valid_op = self.check_valid_edge_op
            else:
                self.check_valid_op = self.check_valid_aten_op
            self._exception_list = exception_list if exception_list else []

        def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
            """Return the types permitted for ``get_attr`` targets."""
            return (
                torch.fx.GraphModule,
                LoweredBackendModule,
                torch.Tensor,
                torch.ScriptObject,
            )

        def allowed_op_types(self):
            """Extend the base allowed op types with Edge overloads and Python functions."""
            return super().allowed_op_types() + (EdgeOpOverload, types.FunctionType)

        def check_valid_edge_op(self, op):
            """Validate a call target when Edge operators are expected.

            Handling by target kind:
            - Allow-listed builtins and exempted operators pass.
            - A plain ``OpOverload`` that is not an ``EdgeOpOverload`` is rejected.
            - An ``EdgeOpOverload`` gets a dim-order check, then the ATen rules
              are applied to its underlying ``_op``.
            - A Python function must be named ``alloc``.
            """
            if not self.enable:
                return
            if (
                op
                in [
                    operator.getitem,
                    torch.ops.aten.sym_size.int,
                    torch.ops.aten.scalar_tensor.default,
                    torch.ops.aten._assert_async.msg,
                    torch.ops.aten._assert_scalar.default,
                ]
                + self._exception_list
            ):
                return

            if isinstance(op, OpOverload) and not isinstance(op, EdgeOpOverload):
                raise SpecViolationError(
                    "Operator {}.{} is not an Edge operator.".format(
                        op.__module__, op.__name__
                    )
                )
            if isinstance(op, EdgeOpOverload):
                _check_valid_dim_order_ops(op._op, self.use_dim_order)
                self.check_valid_aten_op(op._op)

            if isinstance(op, types.FunctionType):
                # NOTE(review): ``assert`` is stripped under ``python -O``, and the
                # AssertionError it raises is not caught by ``is_valid``. Consider
                # raising SpecViolationError instead, once callers are confirmed.
                assert op.__name__ in ("alloc",)

        def check_additional(self, gm: GraphModule) -> None:
            """With Edge ops enabled, require contiguous tensors and dtype-valid op args."""
            if not self.enable:
                return
            if self.check_edge_ops:
                _check_tensors_are_contiguous(gm)
                _check_tensor_args_matching_op_allowed_dtype(gm)

        def check_valid_op(self, op):
            # NOTE(review): ``__init__`` always assigns an instance attribute of the
            # same name, so this method is not reached via ``self.check_valid_op``.
            # Its exemption tuple also differs from the nested ATen verifier's: it
            # lacks ``aten.min.default`` and the caller-supplied exceptions.
            # It is kept unchanged for compatibility with direct class-level calls.
            if isinstance(op, OpOverload):
                # TODO These special ops should be removable easily.
                if op.namespace in (
                    "quantized_decomposed",
                    "boltnn_nimble",
                    "nimble",
                    "quantized",
                    "dim_order_ops",
                ) or op in (
                    torch.ops.aten.mkldnn_rnn_layer.default,
                    torch.ops.aten._upsample_bilinear2d_aa.default,
                    torch.ops.aten.quantize_per_tensor.default,
                    torch.ops.aten.dequantize.self,
                    torch.ops.aten.max.default,
                    torch.ops.aten.full_like.default,  # TODO(T183507359)
                ):
                    return
                if torch.Tag.core not in op.tags and torch.Tag.view_copy not in op.tags:
                    # NOTE(qihan): whether view_copy operators are marked as canonical is still under
                    #            discussion.
                    raise SpecViolationError(
                        f"Operator {op.__module__}.{op.__name__} is not Aten Canonical."
                    )

        def is_valid(self, gm: GraphModule) -> bool:
            """Return True if verifying ``gm`` raises no SpecViolationError.

            Any other exception propagates to the caller.
            """
            try:
                self(gm)
                return True
            except SpecViolationError:
                return False

        def __call__(self, ep_or_gm):
            """Verify an ExportedProgram (via its graph_module) or a GraphModule.

            Does nothing when ``enable`` is false.

            Raises:
                RuntimeError: if neither verification entry point is available.
            """
            if not self.enable:
                return
            gm = ep_or_gm
            if isinstance(gm, ExportedProgram):
                gm = ep_or_gm.graph_module
            if hasattr(self, "_check_graph_module"):
                return self._check_graph_module(gm)
            elif hasattr(self, "check_valid"):
                return self.check_valid(gm)
            else:
                # Previously raised with an empty message; name what is missing instead.
                raise RuntimeError(
                    f"{type(self).__name__} defines neither _check_graph_module "
                    "nor check_valid; cannot run verification."
                )

    ret = _EXIREdgeDialectVerifier
    if not class_only:
        ret = ret()
    return ret
313*523fa7a6SAndroid Build Coastguard Worker
314*523fa7a6SAndroid Build Coastguard Worker
# Construct a default-configured Edge verifier once at import time; the result
# is discarded. NOTE(review): presumably a smoke check that construction with
# default settings succeeds — confirm intent before removing.
EXIREdgeDialectVerifier(edge_compile_config=None, class_only=False)
316