1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates. 2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved. 3*523fa7a6SAndroid Build Coastguard Worker# 4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the 5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree. 6*523fa7a6SAndroid Build Coastguard Worker 7*523fa7a6SAndroid Build Coastguard Workerimport itertools 8*523fa7a6SAndroid Build Coastguard Workerimport operator 9*523fa7a6SAndroid Build Coastguard Workerimport types 10*523fa7a6SAndroid Build Coastguard Workerfrom contextlib import nullcontext 11*523fa7a6SAndroid Build Coastguard Workerfrom typing import Any, List, Optional, Tuple, Type 12*523fa7a6SAndroid Build Coastguard Worker 13*523fa7a6SAndroid Build Coastguard Workerimport torch 14*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.capture._config import EdgeCompileConfig 15*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.dialects.edge._ops import EdgeOpOverload 16*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.error import ExportError, ExportErrorType 17*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.lowered_backend_module import LoweredBackendModule 18*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.verification.arg_validator import ( 19*523fa7a6SAndroid Build Coastguard Worker EdgeOpArgValidator, 20*523fa7a6SAndroid Build Coastguard Worker RunHigherOrderOperatorError, 21*523fa7a6SAndroid Build Coastguard Worker) 22*523fa7a6SAndroid Build Coastguard Workerfrom torch._dispatch.python import enable_python_dispatcher 23*523fa7a6SAndroid Build Coastguard Workerfrom torch._export.utils import _detect_fake_mode_from_gm 24*523fa7a6SAndroid Build Coastguard Worker 25*523fa7a6SAndroid Build Coastguard Workerfrom torch._export.verifier import SpecViolationError, Verifier 26*523fa7a6SAndroid Build 
Coastguard Workerfrom torch._ops import OpOverload 27*523fa7a6SAndroid Build Coastguard Workerfrom torch._subclasses import FakeTensor 28*523fa7a6SAndroid Build Coastguard Workerfrom torch.export.exported_program import ExportedProgram 29*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx import GraphModule 30*523fa7a6SAndroid Build Coastguard Worker 31*523fa7a6SAndroid Build Coastguard Worker 32*523fa7a6SAndroid Build Coastguard WorkerALLOWED_META_KEYS = {"spec", "stack_trace"} 33*523fa7a6SAndroid Build Coastguard Worker 34*523fa7a6SAndroid Build Coastguard Worker 35*523fa7a6SAndroid Build Coastguard Workerdef _check_tensors_are_contiguous(gm: GraphModule) -> None: 36*523fa7a6SAndroid Build Coastguard Worker # Tensors be of contiguous format 37*523fa7a6SAndroid Build Coastguard Worker for name, param in itertools.chain(gm.named_parameters(), gm.named_buffers()): 38*523fa7a6SAndroid Build Coastguard Worker if isinstance(param, torch.Tensor): 39*523fa7a6SAndroid Build Coastguard Worker if not param.is_contiguous(): 40*523fa7a6SAndroid Build Coastguard Worker raise SpecViolationError( 41*523fa7a6SAndroid Build Coastguard Worker f"Tensors in Aten dialect must be contiguous, {name} is not contiguous" 42*523fa7a6SAndroid Build Coastguard Worker ) 43*523fa7a6SAndroid Build Coastguard Worker 44*523fa7a6SAndroid Build Coastguard Worker 45*523fa7a6SAndroid Build Coastguard Workerdef _check_valid_dim_order_ops(op, use_dim_order) -> None: 46*523fa7a6SAndroid Build Coastguard Worker if use_dim_order: 47*523fa7a6SAndroid Build Coastguard Worker if op in (torch.ops.aten._to_copy.default,): 48*523fa7a6SAndroid Build Coastguard Worker raise SpecViolationError(f"{op} should not be used in dim_order mode") 49*523fa7a6SAndroid Build Coastguard Worker else: # not using dim_order 50*523fa7a6SAndroid Build Coastguard Worker if op.namespace in ("dim_order_ops",): 51*523fa7a6SAndroid Build Coastguard Worker raise SpecViolationError(f"{op} should not be used in non-dim_order mode") 


class EXIRATenDialectVerifierBase(Verifier):
    """Base verifier for the legacy EXIR ATen dialect.

    ``get_aten_verifier`` returns this class unchanged when
    ``_check_ir_validity`` is disabled. ``EXIRATenDialectVerifier`` subclasses
    it to add core-ATen operator checks.
    """

    dialect = "OLD_EXIR_ATEN_DISABLED"

    def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
        # get_attr targets may be submodules, lowered delegates, tensors or
        # TorchScript objects.
        return (
            torch.fx.GraphModule,
            LoweredBackendModule,
            torch.Tensor,
            torch.ScriptObject,
        )

    def allowed_op_types(self):
        # Extend the base allowlist with OpOverloadPacket call targets.
        return super().allowed_op_types() + (torch._ops.OpOverloadPacket,)

    def __call__(self, *args, **kwargs):
        """Verify by delegating to ``_check_graph_module``, else ``check_valid``.

        NOTE(review): the ``hasattr`` dispatch presumably tolerates torch
        versions that expose only one of these entry points — confirm.

        Raises:
            RuntimeError: if neither entry point exists on this verifier.
        """
        if hasattr(self, "_check_graph_module"):
            return self._check_graph_module(*args, **kwargs)
        elif hasattr(self, "check_valid"):
            return self.check_valid(*args, **kwargs)
        else:
            raise RuntimeError(
                f"{type(self).__name__} defines neither `_check_graph_module` nor "
                "`check_valid`; cannot run verification."
            )


def EXIRATenDialectVerifier(  # noqa: C901
    edge_compile_config: Optional[EdgeCompileConfig] = None,
    class_only: bool = False,
    exception_list: Optional[List[torch._ops.OpOverload]] = None,
):
    """
    Build a verifier that runs ATen dialect specific checks on the graph module.

    Args:
        edge_compile_config: optional config whose
            ``_core_aten_ops_exception_list`` is prepended to ``exception_list``.
        class_only: if True, return the verifier class; otherwise return a
            freshly constructed instance of it.
        exception_list: extra operators tolerated even though they are not
            core/view_copy ATen operators.

    Returns:
        The ``_EXIRATenDialectVerifier`` class, or an instance of it.
    """
    # merge the exception list from edge_compile_config and exception_list
    if edge_compile_config and edge_compile_config._core_aten_ops_exception_list:
        exception_list = edge_compile_config._core_aten_ops_exception_list + (
            exception_list or []
        )

    class _EXIRATenDialectVerifier(EXIRATenDialectVerifierBase):
        dialect = "OLD_EXIR_ATEN"

        def __init__(self) -> None:
            super().__init__()
            # Note: here we are using the exception list passed from EXIRATenDialectVerifier function!
            self._exception_list = exception_list if exception_list else []

        def _get_exception_list(self) -> List[torch._ops.OpOverload]:
            """Return the always-tolerated ops followed by the caller-supplied ones."""
            exception_list = [
                torch.ops.aten.mkldnn_rnn_layer.default,
                torch.ops.aten._upsample_bilinear2d_aa.default,
                torch.ops.aten.quantize_per_tensor.default,
                torch.ops.aten.dequantize.self,
                torch.ops.aten.max.default,  # TODO(T188268054)
                torch.ops.aten.min.default,  # TODO(T188268054)
                torch.ops.aten.full_like.default,  # TODO(T183507359)
            ]
            exception_list += self._exception_list

            return exception_list

        def check_valid_op(self, op):
            """Reject aten ``OpOverload``s tagged neither ``core`` nor ``view_copy``.

            Non-``OpOverload`` targets, non-aten namespaces and excepted ops
            are accepted silently.

            Raises:
                SpecViolationError: for a non-canonical aten operator.
            """
            if isinstance(op, OpOverload):
                # TODO These special ops should be removable easily.
                if op.namespace != "aten" or op in self._get_exception_list():
                    return
                if torch.Tag.core not in op.tags and torch.Tag.view_copy not in op.tags:
                    # NOTE(qihan): whether view_copy operators are marked as canonical is still under
                    # discussion.
                    raise SpecViolationError(
                        f"Operator {op.__module__}.{op.__name__} is not Aten Canonical."
                    )

    ret = _EXIRATenDialectVerifier
    if not class_only:
        ret = ret()
    return ret


def get_aten_verifier(config: EdgeCompileConfig):
    """Return the ATen verifier class to use for ``config``.

    With ``config._check_ir_validity`` set, this is a core-ATen-checking class
    that honours ``config._core_aten_ops_exception_list``. Otherwise it is
    ``EXIRATenDialectVerifierBase``. A class, not an instance, is returned in
    both cases.
    """
    return (
        EXIRATenDialectVerifier(
            class_only=True, exception_list=config._core_aten_ops_exception_list
        )
        if config._check_ir_validity
        else EXIRATenDialectVerifierBase
    )


def _get_inputs(graph_module: GraphModule) -> List[Optional[FakeTensor]]:
    """Collect ``meta["val"]`` for every placeholder of ``graph_module``, in order.

    Placeholders lacking ``val`` that have no users contribute ``None``.

    Raises:
        ExportError: (VIOLATION_OF_SPEC) if a placeholder with users lacks ``val``.
    """

    def extract_input(node: torch.fx.Node) -> Optional[FakeTensor]:
        if "val" in node.meta:
            return node.meta["val"]

        if len(node.users) == 0:
            return None

        # TODO(ycao): `val` should always exist after we enable shape environment
        # serialization and deserialization.
        raise ExportError(
            ExportErrorType.VIOLATION_OF_SPEC,
            f"Cannot construct an input for graph module: {graph_module}.",
        )

    return [
        extract_input(node)
        for node in graph_module.graph.nodes
        if node.op == "placeholder"
    ]


def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:
    """Run ``EdgeOpArgValidator`` over ``gm`` using its placeholder values.

    Execution happens under the python dispatcher and, when one can be
    detected from ``gm``, its fake mode.

    Raises:
        SpecViolationError: if the validator recorded any ``violating_ops``.
    """
    validator = EdgeOpArgValidator(gm)
    inputs = _get_inputs(gm)
    fake_mode = _detect_fake_mode_from_gm(gm) or nullcontext()
    try:
        with enable_python_dispatcher(), fake_mode:
            validator.run(*inputs)
    except RunHigherOrderOperatorError:
        # NB: ignore higher order operator in the graph.
        # If we lower a graph module to delegate and then compose it with some other graph module, retrace it,
        # if we also turn on edge ops and validator (_check_ir_validity=True), we will run
        # into RunHigherOrderOperatorError. The only thing we can do right now is to ignore this error, since
        # by definition it's still a valid Edge dialect. This is not ideal because it ignores possible invalidity
        # later in the graph.
        return

    if validator.violating_ops:
        raise SpecViolationError(
            f"These operators are taking Tensor inputs with mismatched dtypes: {validator.violating_ops}"
        )


def EXIREdgeDialectVerifier(  # noqa: C901
    edge_compile_config: Optional[EdgeCompileConfig] = None,
    class_only: bool = False,
    exception_list: Optional[List[torch._ops.OpOverload]] = None,
):
    """
    Build a verifier for the EXIR Edge dialect.

    Args:
        edge_compile_config: selects which checks run (``_check_ir_validity``,
            ``_use_edge_ops``, ``_skip_dim_order``). A default
            ``EdgeCompileConfig()`` is used when omitted. Its
            ``_core_aten_ops_exception_list`` is prepended to ``exception_list``.
        class_only: if True, return the verifier class; otherwise an instance.
        exception_list: extra operators tolerated by the op checks.

    Returns:
        The ``_EXIREdgeDialectVerifier`` class, or an instance of it.
    """
    # merge the exception list from edge_compile_config and exception_list
    if edge_compile_config and edge_compile_config._core_aten_ops_exception_list:
        exception_list = edge_compile_config._core_aten_ops_exception_list + (
            exception_list or []
        )

    class _EXIREdgeDialectVerifier(Verifier):
        dialect = "EDGE"

        def __init__(self) -> None:
            _edge_compile_config = edge_compile_config or EdgeCompileConfig()

            self.enable = _edge_compile_config._check_ir_validity
            self.check_edge_ops = _edge_compile_config._use_edge_ops
            self.use_dim_order = not _edge_compile_config._skip_dim_order

            self.aten_op_verifier = EXIRATenDialectVerifier(
                exception_list=exception_list
            )
            self.check_valid_aten_op = self.aten_op_verifier.check_valid_op

            # This instance attribute shadows the class-level `check_valid_op`
            # defined below, so instances built through __init__ never use it.
            if self.check_edge_ops:
                self.check_valid_op = self.check_valid_edge_op
            else:
                self.check_valid_op = self.check_valid_aten_op
            self._exception_list = exception_list if exception_list else []

        def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
            return (
                torch.fx.GraphModule,
                LoweredBackendModule,
                torch.Tensor,
                torch.ScriptObject,
            )

        def allowed_op_types(self):
            # Edge graphs may call EdgeOpOverloads and plain Python functions;
            # check_valid_edge_op narrows the latter down to `alloc`.
            return super().allowed_op_types() + (EdgeOpOverload, types.FunctionType)

        def check_valid_edge_op(self, op):
            """Validate a call target when Edge ops are required.

            Accepted: a fixed set of utility ops plus ``self._exception_list``;
            ``EdgeOpOverload``s whose underlying op passes the dim-order and
            ATen checks; and the Python function ``alloc``. No-op when IR
            validity checking is disabled.

            Raises:
                SpecViolationError: for a non-Edge ``OpOverload``, an invalid
                    underlying op, or a Python function other than ``alloc``.
            """
            if not self.enable:
                return
            if (
                op
                in [
                    operator.getitem,
                    torch.ops.aten.sym_size.int,
                    torch.ops.aten.scalar_tensor.default,
                    torch.ops.aten._assert_async.msg,
                    torch.ops.aten._assert_scalar.default,
                ]
                + self._exception_list
            ):
                return

            if isinstance(op, OpOverload) and not isinstance(op, EdgeOpOverload):
                raise SpecViolationError(
                    "Operator {}.{} is not an Edge operator.".format(
                        op.__module__, op.__name__
                    )
                )
            if isinstance(op, EdgeOpOverload):
                _check_valid_dim_order_ops(op._op, self.use_dim_order)
                self.check_valid_aten_op(op._op)

            if isinstance(op, types.FunctionType):
                # Raise instead of `assert`: asserts vanish under `python -O`,
                # which would let any Python function through, and an
                # AssertionError escapes `is_valid`, which only catches
                # SpecViolationError.
                if op.__name__ not in ("alloc",):
                    raise SpecViolationError(
                        f"Python function {op.__module__}.{op.__name__} is not an "
                        "allowed call target in Edge dialect; only `alloc` is."
                    )

        def check_additional(self, gm: GraphModule) -> None:
            """When enabled with Edge ops, check tensor contiguity and arg dtypes."""
            if not self.enable:
                return
            if self.check_edge_ops:
                _check_tensors_are_contiguous(gm)
                _check_tensor_args_matching_op_allowed_dtype(gm)

        def check_valid_op(self, op):
            # NOTE(review): shadowed by the instance attribute assigned in
            # __init__, so it is only reachable when called through the class.
            # Its exception list also differs from the ATen verifier's (no
            # aten.min.default) — confirm whether it can be removed.
            if isinstance(op, OpOverload):
                # TODO These special ops should be removable easily.
                if op.namespace in (
                    "quantized_decomposed",
                    "boltnn_nimble",
                    "nimble",
                    "quantized",
                    "dim_order_ops",
                ) or op in (
                    torch.ops.aten.mkldnn_rnn_layer.default,
                    torch.ops.aten._upsample_bilinear2d_aa.default,
                    torch.ops.aten.quantize_per_tensor.default,
                    torch.ops.aten.dequantize.self,
                    torch.ops.aten.max.default,
                    torch.ops.aten.full_like.default,  # TODO(T183507359)
                ):
                    return
                if torch.Tag.core not in op.tags and torch.Tag.view_copy not in op.tags:
                    # NOTE(qihan): whether view_copy operators are marked as canonical is still under
                    # discussion.
                    raise SpecViolationError(
                        f"Operator {op.__module__}.{op.__name__} is not Aten Canonical."
                    )

        def is_valid(self, gm: GraphModule) -> bool:
            """Return True if verifying ``gm`` raises no SpecViolationError."""
            try:
                self(gm)
                return True
            except SpecViolationError:
                return False

        def __call__(self, ep_or_gm):
            """Verify an ExportedProgram (via its graph_module) or a GraphModule.

            Does nothing when IR validity checking is disabled.

            Raises:
                RuntimeError: if neither ``_check_graph_module`` nor
                    ``check_valid`` exists on this verifier.
            """
            if not self.enable:
                return
            gm = ep_or_gm
            if isinstance(gm, ExportedProgram):
                gm = ep_or_gm.graph_module
            if hasattr(self, "_check_graph_module"):
                return self._check_graph_module(gm)
            elif hasattr(self, "check_valid"):
                return self.check_valid(gm)
            else:
                raise RuntimeError(
                    f"{type(self).__name__} defines neither `_check_graph_module` "
                    "nor `check_valid`; cannot run verification."
                )

    ret = _EXIREdgeDialectVerifier
    if not class_only:
        ret = ret()
    return ret


# Build a default Edge verifier at import time and discard the result.
# NOTE(review): presumably kept for its side effects (defining the "EDGE" and,
# via its __init__, the "OLD_EXIR_ATEN" verifier classes) — confirm before
# removing.
EXIREdgeDialectVerifier(
    edge_compile_config=None,
    class_only=False,
    exception_list=None,
)