1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates. 2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved. 3*523fa7a6SAndroid Build Coastguard Worker# 4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the 5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree. 6*523fa7a6SAndroid Build Coastguard Worker 7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict 8*523fa7a6SAndroid Build Coastguard Worker 9*523fa7a6SAndroid Build Coastguard Workerimport operator 10*523fa7a6SAndroid Build Coastguard Workerimport traceback 11*523fa7a6SAndroid Build Coastguard Workerfrom contextlib import nullcontext 12*523fa7a6SAndroid Build Coastguard Workerfrom typing import ( 13*523fa7a6SAndroid Build Coastguard Worker Any, 14*523fa7a6SAndroid Build Coastguard Worker Callable, 15*523fa7a6SAndroid Build Coastguard Worker Dict, 16*523fa7a6SAndroid Build Coastguard Worker List, 17*523fa7a6SAndroid Build Coastguard Worker MutableMapping, 18*523fa7a6SAndroid Build Coastguard Worker Optional, 19*523fa7a6SAndroid Build Coastguard Worker Protocol, 20*523fa7a6SAndroid Build Coastguard Worker runtime_checkable, 21*523fa7a6SAndroid Build Coastguard Worker Set, 22*523fa7a6SAndroid Build Coastguard Worker Tuple, 23*523fa7a6SAndroid Build Coastguard Worker TypeVar, 24*523fa7a6SAndroid Build Coastguard Worker Union, 25*523fa7a6SAndroid Build Coastguard Worker) 26*523fa7a6SAndroid Build Coastguard Worker 27*523fa7a6SAndroid Build Coastguard Workerimport torch 28*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir import memory 29*523fa7a6SAndroid Build Coastguard Worker 30*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.delegate import executorch_call_delegate, is_lowered_module 31*523fa7a6SAndroid Build Coastguard Worker 32*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.dialects.edge._ops import EdgeOpOverload 33*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.error import ExportError, ExportErrorType 34*523fa7a6SAndroid Build Coastguard Workerfrom torch import fx 35*523fa7a6SAndroid Build Coastguard Workerfrom torch._dispatch.python import enable_python_dispatcher 36*523fa7a6SAndroid Build Coastguard Workerfrom torch._subclasses import FakeTensorMode, UnsupportedFakeTensorException 37*523fa7a6SAndroid Build Coastguard Workerfrom torch._subclasses.fake_tensor import FakeTensor 38*523fa7a6SAndroid Build Coastguard Workerfrom torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode 39*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx import traceback as fx_traceback 40*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.experimental.proxy_tensor import PythonKeyTracer 41*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.graph import CodeGen 42*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.passes.infra.pass_base import PassBase, PassResult 43*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata 44*523fa7a6SAndroid Build Coastguard Workerfrom torch.utils import _pytree as pytree 45*523fa7a6SAndroid Build Coastguard Workerfrom torch.utils._pytree import PyTree 46*523fa7a6SAndroid Build Coastguard Worker 47*523fa7a6SAndroid Build Coastguard WorkerFn = Callable[..., Any] # pyre-ignore 48*523fa7a6SAndroid Build Coastguard WorkerArgument = Any # pyre-ignore 49*523fa7a6SAndroid Build Coastguard WorkerValue = Any # pyre-ignore 50*523fa7a6SAndroid Build Coastguard WorkerNodeMetadataValue = Any # pyre-ignore 51*523fa7a6SAndroid Build Coastguard WorkerK = TypeVar("K") 52*523fa7a6SAndroid Build Coastguard WorkerPassType = Callable[[torch.fx.GraphModule], Optional[PassResult]] 53*523fa7a6SAndroid Build Coastguard Worker 54*523fa7a6SAndroid Build Coastguard Worker 55*523fa7a6SAndroid Build Coastguard Worker_TORCH_SYM_OPS: Set[Any] = { # pyre-ignore 56*523fa7a6SAndroid Build Coastguard Worker torch.sym_int, 57*523fa7a6SAndroid Build Coastguard Worker torch.sym_float, 58*523fa7a6SAndroid Build Coastguard Worker torch.sym_ite, 59*523fa7a6SAndroid Build Coastguard Worker torch.sym_max, 60*523fa7a6SAndroid Build Coastguard Worker torch.sym_min, 61*523fa7a6SAndroid Build Coastguard Worker torch.sym_not, 62*523fa7a6SAndroid Build Coastguard Worker torch.sym_sqrt, 63*523fa7a6SAndroid Build Coastguard Worker} 64*523fa7a6SAndroid Build Coastguard Worker 65*523fa7a6SAndroid Build Coastguard Worker 66*523fa7a6SAndroid Build Coastguard WorkerPROTECTED_KEYS: Set[str] = { 67*523fa7a6SAndroid Build Coastguard Worker "val", 68*523fa7a6SAndroid Build Coastguard Worker "stack_trace", 69*523fa7a6SAndroid Build Coastguard Worker "nn_module_stack", 70*523fa7a6SAndroid Build Coastguard Worker "debug_handle", 71*523fa7a6SAndroid Build Coastguard Worker "tensor_meta", 72*523fa7a6SAndroid Build Coastguard Worker} 73*523fa7a6SAndroid Build Coastguard Worker 74*523fa7a6SAndroid Build Coastguard Worker 75*523fa7a6SAndroid Build Coastguard Workerdef _unstack_pytree(xs) -> List[PyTree]: # pyre-ignore 76*523fa7a6SAndroid Build Coastguard Worker flat_xs, inspec = pytree.tree_flatten(xs) 77*523fa7a6SAndroid Build Coastguard Worker if not all(isinstance(xs, torch.Tensor) for xs in flat_xs): 78*523fa7a6SAndroid Build Coastguard Worker raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}") 79*523fa7a6SAndroid Build Coastguard Worker 80*523fa7a6SAndroid Build Coastguard Worker if not all(xs.shape[0] == flat_xs[0].shape[0] for xs in flat_xs): 81*523fa7a6SAndroid Build Coastguard Worker raise RuntimeError( 82*523fa7a6SAndroid Build Coastguard Worker f"Leaves of xs must have same leading dimension size {[xs.shape for xs in flat_xs]}" 83*523fa7a6SAndroid Build Coastguard Worker ) 84*523fa7a6SAndroid Build Coastguard Worker 85*523fa7a6SAndroid Build Coastguard Worker ctx = ( 86*523fa7a6SAndroid Build Coastguard Worker FunctionalTensorMode 87*523fa7a6SAndroid Build Coastguard Worker if any(isinstance(x, FunctionalTensor) for x in flat_xs) 88*523fa7a6SAndroid Build Coastguard Worker else nullcontext 89*523fa7a6SAndroid Build Coastguard Worker ) 90*523fa7a6SAndroid Build Coastguard Worker with ctx(): 91*523fa7a6SAndroid Build Coastguard Worker a = zip(*flat_xs) 92*523fa7a6SAndroid Build Coastguard Worker 93*523fa7a6SAndroid Build Coastguard Worker pytrees = [] 94*523fa7a6SAndroid Build Coastguard Worker for tuple in a: 95*523fa7a6SAndroid Build Coastguard Worker pytrees.append(pytree.tree_unflatten(tuple, inspec)) 96*523fa7a6SAndroid Build Coastguard Worker return pytrees 97*523fa7a6SAndroid Build Coastguard Worker 98*523fa7a6SAndroid Build Coastguard Worker 99*523fa7a6SAndroid Build Coastguard Workerclass NodeMetadata: 100*523fa7a6SAndroid Build Coastguard Worker def __init__(self, data: Dict[str, Any]) -> None: 101*523fa7a6SAndroid Build Coastguard Worker self.data: Dict[str, Any] = data.copy() 102*523fa7a6SAndroid Build Coastguard Worker 103*523fa7a6SAndroid Build Coastguard Worker def __getitem__(self, key: str) -> NodeMetadataValue: 104*523fa7a6SAndroid Build Coastguard Worker return self.data[key] 105*523fa7a6SAndroid Build Coastguard Worker 106*523fa7a6SAndroid Build Coastguard Worker def __setitem__(self, key: str, value: NodeMetadataValue) -> NodeMetadataValue: 107*523fa7a6SAndroid Build Coastguard Worker if key in PROTECTED_KEYS: 108*523fa7a6SAndroid Build Coastguard Worker raise RuntimeError(f"Could not override node key: {key}") 109*523fa7a6SAndroid Build Coastguard Worker self.data[key] = value 110*523fa7a6SAndroid Build Coastguard Worker 111*523fa7a6SAndroid Build Coastguard Worker def __contains__(self, key: str) -> bool: 112*523fa7a6SAndroid Build Coastguard Worker return key in self.data 113*523fa7a6SAndroid Build Coastguard Worker 114*523fa7a6SAndroid Build Coastguard Worker def copy(self) -> "NodeMetadata": 115*523fa7a6SAndroid Build Coastguard Worker return NodeMetadata(self.data.copy()) 116*523fa7a6SAndroid Build Coastguard Worker 117*523fa7a6SAndroid Build Coastguard Worker 118*523fa7a6SAndroid Build Coastguard Workerclass ProxyValue: 119*523fa7a6SAndroid Build Coastguard Worker # pyre-ignore 120*523fa7a6SAndroid Build Coastguard Worker def __init__(self, data, proxy: Union[torch.fx.Proxy, torch.fx.Node]): 121*523fa7a6SAndroid Build Coastguard Worker # pyre-ignore 122*523fa7a6SAndroid Build Coastguard Worker self.data = data 123*523fa7a6SAndroid Build Coastguard Worker self.proxy_or_node = proxy 124*523fa7a6SAndroid Build Coastguard Worker 125*523fa7a6SAndroid Build Coastguard Worker @property 126*523fa7a6SAndroid Build Coastguard Worker def node(self) -> torch.fx.Node: 127*523fa7a6SAndroid Build Coastguard Worker if isinstance(self.proxy_or_node, torch.fx.Node): 128*523fa7a6SAndroid Build Coastguard Worker return self.proxy_or_node 129*523fa7a6SAndroid Build Coastguard Worker assert isinstance(self.proxy_or_node, torch.fx.Proxy) 130*523fa7a6SAndroid Build Coastguard Worker return self.proxy_or_node.node 131*523fa7a6SAndroid Build Coastguard Worker 132*523fa7a6SAndroid Build Coastguard Worker @property 133*523fa7a6SAndroid Build Coastguard Worker def proxy(self) -> torch.fx.Proxy: 134*523fa7a6SAndroid Build Coastguard Worker if not isinstance(self.proxy_or_node, torch.fx.Proxy): 135*523fa7a6SAndroid Build Coastguard Worker raise RuntimeError( 136*523fa7a6SAndroid Build Coastguard Worker f"ProxyValue doesn't have attached Proxy object. Node: {self.proxy_or_node.format_node()}" 137*523fa7a6SAndroid Build Coastguard Worker ) 138*523fa7a6SAndroid Build Coastguard Worker return self.proxy_or_node 139*523fa7a6SAndroid Build Coastguard Worker 140*523fa7a6SAndroid Build Coastguard Worker def to_tensor(self) -> torch.Tensor: 141*523fa7a6SAndroid Build Coastguard Worker assert isinstance(self.data, torch.Tensor) 142*523fa7a6SAndroid Build Coastguard Worker return self.data 143*523fa7a6SAndroid Build Coastguard Worker 144*523fa7a6SAndroid Build Coastguard Worker def is_tensor(self) -> bool: 145*523fa7a6SAndroid Build Coastguard Worker return isinstance(self.data, torch.Tensor) 146*523fa7a6SAndroid Build Coastguard Worker 147*523fa7a6SAndroid Build Coastguard Worker # pyre-ignore 148*523fa7a6SAndroid Build Coastguard Worker def __iter__(self): 149*523fa7a6SAndroid Build Coastguard Worker yield from self.data 150*523fa7a6SAndroid Build Coastguard Worker 151*523fa7a6SAndroid Build Coastguard Worker def __bool__(self) -> bool: 152*523fa7a6SAndroid Build Coastguard Worker return bool(self.data) 153*523fa7a6SAndroid Build Coastguard Worker 154*523fa7a6SAndroid Build Coastguard Worker 155*523fa7a6SAndroid Build Coastguard Workerclass ExportPassBaseError(RuntimeError): 156*523fa7a6SAndroid Build Coastguard Worker pass 157*523fa7a6SAndroid Build Coastguard Worker 158*523fa7a6SAndroid Build Coastguard Worker 159*523fa7a6SAndroid Build Coastguard Workerclass _ExportPassBase(PassBase): 160*523fa7a6SAndroid Build Coastguard Worker """ 161*523fa7a6SAndroid Build Coastguard Worker Interpreter-based pass class to help users maintain the IR spec while writing 162*523fa7a6SAndroid Build Coastguard Worker transformations. 163*523fa7a6SAndroid Build Coastguard Worker """ 164*523fa7a6SAndroid Build Coastguard Worker 165*523fa7a6SAndroid Build Coastguard Worker @staticmethod 166*523fa7a6SAndroid Build Coastguard Worker def _create_dummy_node_metadata() -> NodeMetadata: 167*523fa7a6SAndroid Build Coastguard Worker return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))}) 168*523fa7a6SAndroid Build Coastguard Worker 169*523fa7a6SAndroid Build Coastguard Worker class ExportTracer(PythonKeyTracer): 170*523fa7a6SAndroid Build Coastguard Worker def __init__(self, callback: "_ExportPassBase", codegen: CodeGen) -> None: 171*523fa7a6SAndroid Build Coastguard Worker super().__init__() 172*523fa7a6SAndroid Build Coastguard Worker self.callback = callback 173*523fa7a6SAndroid Build Coastguard Worker self.root = torch.nn.Module() 174*523fa7a6SAndroid Build Coastguard Worker self.graph = torch.fx.Graph() 175*523fa7a6SAndroid Build Coastguard Worker self.graph.set_codegen(codegen) 176*523fa7a6SAndroid Build Coastguard Worker self.tensor_attrs: Dict[str, torch.Tensor] = {} # type: ignore[assignment] 177*523fa7a6SAndroid Build Coastguard Worker self.fake_tensor_mode: Optional[FakeTensorMode] = None 178*523fa7a6SAndroid Build Coastguard Worker self.submodules: Dict[torch.nn.Module, str] = {} 179*523fa7a6SAndroid Build Coastguard Worker 180*523fa7a6SAndroid Build Coastguard Worker def trace(self) -> None: # pyre-fixme[14,15] 181*523fa7a6SAndroid Build Coastguard Worker raise ExportPassBaseError("ExportTracer doesn't support trace().") 182*523fa7a6SAndroid Build Coastguard Worker 183*523fa7a6SAndroid Build Coastguard Worker def create_arg(self, a: Argument) -> torch.fx.Node: 184*523fa7a6SAndroid Build Coastguard Worker if isinstance(a, torch.nn.Module): 185*523fa7a6SAndroid Build Coastguard Worker if a not in self.submodules: 186*523fa7a6SAndroid Build Coastguard Worker name_submodule = f"submodule_{len(self.submodules)}" 187*523fa7a6SAndroid Build Coastguard Worker self.root.add_module(name_submodule, a) 188*523fa7a6SAndroid Build Coastguard Worker self.submodules[a] = name_submodule 189*523fa7a6SAndroid Build Coastguard Worker elif isinstance(a, FakeTensor): 190*523fa7a6SAndroid Build Coastguard Worker if not hasattr(a, "constant") or a.constant is None: 191*523fa7a6SAndroid Build Coastguard Worker raise ExportPassBaseError(f"Cannot add {a} to graph.") 192*523fa7a6SAndroid Build Coastguard Worker a = a.constant 193*523fa7a6SAndroid Build Coastguard Worker node = super().create_arg(a) 194*523fa7a6SAndroid Build Coastguard Worker if ( 195*523fa7a6SAndroid Build Coastguard Worker isinstance(a, torch.Tensor) 196*523fa7a6SAndroid Build Coastguard Worker and isinstance(node, torch.fx.Node) 197*523fa7a6SAndroid Build Coastguard Worker and node.op == "get_attr" 198*523fa7a6SAndroid Build Coastguard Worker ): 199*523fa7a6SAndroid Build Coastguard Worker self.set_metadata(node, a) 200*523fa7a6SAndroid Build Coastguard Worker self.callback.on_attr(ProxyValue(a, node)) 201*523fa7a6SAndroid Build Coastguard Worker return node 202*523fa7a6SAndroid Build Coastguard Worker 203*523fa7a6SAndroid Build Coastguard Worker def set_metadata( # noqa: C901 204*523fa7a6SAndroid Build Coastguard Worker self, 205*523fa7a6SAndroid Build Coastguard Worker node: torch.fx.Node, 206*523fa7a6SAndroid Build Coastguard Worker value: Argument, 207*523fa7a6SAndroid Build Coastguard Worker ) -> None: 208*523fa7a6SAndroid Build Coastguard Worker # propagate the fake tensor or sym nodes 209*523fa7a6SAndroid Build Coastguard Worker def make_val( 210*523fa7a6SAndroid Build Coastguard Worker x: Argument, 211*523fa7a6SAndroid Build Coastguard Worker ) -> Union[ 212*523fa7a6SAndroid Build Coastguard Worker FakeTensor, 213*523fa7a6SAndroid Build Coastguard Worker torch.SymInt, 214*523fa7a6SAndroid Build Coastguard Worker torch.SymFloat, 215*523fa7a6SAndroid Build Coastguard Worker torch.SymBool, 216*523fa7a6SAndroid Build Coastguard Worker int, 217*523fa7a6SAndroid Build Coastguard Worker float, 218*523fa7a6SAndroid Build Coastguard Worker bool, 219*523fa7a6SAndroid Build Coastguard Worker str, 220*523fa7a6SAndroid Build Coastguard Worker None, 221*523fa7a6SAndroid Build Coastguard Worker ]: 222*523fa7a6SAndroid Build Coastguard Worker if isinstance(x, FakeTensor): 223*523fa7a6SAndroid Build Coastguard Worker return x 224*523fa7a6SAndroid Build Coastguard Worker elif isinstance(x, torch.Tensor): 225*523fa7a6SAndroid Build Coastguard Worker if x.is_quantized: 226*523fa7a6SAndroid Build Coastguard Worker # TODO (tmanlaibaatar) properly support Quantized FakeTensor 227*523fa7a6SAndroid Build Coastguard Worker x = torch.dequantize(x) 228*523fa7a6SAndroid Build Coastguard Worker 229*523fa7a6SAndroid Build Coastguard Worker try: 230*523fa7a6SAndroid Build Coastguard Worker assert self.fake_tensor_mode is not None 231*523fa7a6SAndroid Build Coastguard Worker # TODO we should allocate static shapes 232*523fa7a6SAndroid Build Coastguard Worker # for param/buffer values 233*523fa7a6SAndroid Build Coastguard Worker if isinstance(x, torch.nn.Parameter): 234*523fa7a6SAndroid Build Coastguard Worker fake_tensor = self.fake_tensor_mode.from_tensor( 235*523fa7a6SAndroid Build Coastguard Worker x, static_shapes=True 236*523fa7a6SAndroid Build Coastguard Worker ) 237*523fa7a6SAndroid Build Coastguard Worker else: 238*523fa7a6SAndroid Build Coastguard Worker fake_tensor = self.fake_tensor_mode.from_tensor(x) 239*523fa7a6SAndroid Build Coastguard Worker except UnsupportedFakeTensorException: 240*523fa7a6SAndroid Build Coastguard Worker # TODO: This is just a workaround to get over the 241*523fa7a6SAndroid Build Coastguard Worker # x.as_subclass error 242*523fa7a6SAndroid Build Coastguard Worker print( 243*523fa7a6SAndroid Build Coastguard Worker "Fakeifying a Tensor subclass is not supported \ 244*523fa7a6SAndroid Build Coastguard Worker right now. Instead a TensorMetadata is used." 245*523fa7a6SAndroid Build Coastguard Worker ) 246*523fa7a6SAndroid Build Coastguard Worker fake_tensor = None 247*523fa7a6SAndroid Build Coastguard Worker return fake_tensor 248*523fa7a6SAndroid Build Coastguard Worker elif isinstance( 249*523fa7a6SAndroid Build Coastguard Worker x, 250*523fa7a6SAndroid Build Coastguard Worker ( 251*523fa7a6SAndroid Build Coastguard Worker torch.SymInt, 252*523fa7a6SAndroid Build Coastguard Worker torch.SymFloat, 253*523fa7a6SAndroid Build Coastguard Worker torch.SymBool, 254*523fa7a6SAndroid Build Coastguard Worker int, 255*523fa7a6SAndroid Build Coastguard Worker float, 256*523fa7a6SAndroid Build Coastguard Worker bool, 257*523fa7a6SAndroid Build Coastguard Worker str, 258*523fa7a6SAndroid Build Coastguard Worker ), 259*523fa7a6SAndroid Build Coastguard Worker ): 260*523fa7a6SAndroid Build Coastguard Worker return x 261*523fa7a6SAndroid Build Coastguard Worker else: 262*523fa7a6SAndroid Build Coastguard Worker return None 263*523fa7a6SAndroid Build Coastguard Worker 264*523fa7a6SAndroid Build Coastguard Worker node.meta["val"] = pytree.tree_map(make_val, value) 265*523fa7a6SAndroid Build Coastguard Worker 266*523fa7a6SAndroid Build Coastguard Worker # Set the tensor_metadata for values that do not have a corresponding FakeTensor 267*523fa7a6SAndroid Build Coastguard Worker def make_tensor_meta(x: Argument) -> Optional[TensorMetadata]: 268*523fa7a6SAndroid Build Coastguard Worker if not isinstance(x, FakeTensor) and isinstance(x, torch.Tensor): 269*523fa7a6SAndroid Build Coastguard Worker if x.is_quantized: 270*523fa7a6SAndroid Build Coastguard Worker # TODO (tmanlaibaatar) properly support Quantized FakeTensor 271*523fa7a6SAndroid Build Coastguard Worker x = torch.dequantize(x) 272*523fa7a6SAndroid Build Coastguard Worker 273*523fa7a6SAndroid Build Coastguard Worker try: 274*523fa7a6SAndroid Build Coastguard Worker assert self.fake_tensor_mode is not None 275*523fa7a6SAndroid Build Coastguard Worker _ = self.fake_tensor_mode.from_tensor(x) 276*523fa7a6SAndroid Build Coastguard Worker tensor_meta = None 277*523fa7a6SAndroid Build Coastguard Worker except UnsupportedFakeTensorException: 278*523fa7a6SAndroid Build Coastguard Worker # TODO: This is just a workaround to get over the 279*523fa7a6SAndroid Build Coastguard Worker # x.as_subclass error 280*523fa7a6SAndroid Build Coastguard Worker tensor_meta = _extract_tensor_metadata(x) 281*523fa7a6SAndroid Build Coastguard Worker return tensor_meta 282*523fa7a6SAndroid Build Coastguard Worker else: 283*523fa7a6SAndroid Build Coastguard Worker return None 284*523fa7a6SAndroid Build Coastguard Worker 285*523fa7a6SAndroid Build Coastguard Worker node.meta["tensor_meta"] = pytree.tree_map(make_tensor_meta, value) 286*523fa7a6SAndroid Build Coastguard Worker 287*523fa7a6SAndroid Build Coastguard Worker class ExportInterpreter(fx.Interpreter): 288*523fa7a6SAndroid Build Coastguard Worker def __init__(self, callback: "_ExportPassBase", gm: fx.GraphModule) -> None: 289*523fa7a6SAndroid Build Coastguard Worker super().__init__(gm) 290*523fa7a6SAndroid Build Coastguard Worker self.callback = callback 291*523fa7a6SAndroid Build Coastguard Worker self.node: torch.fx.Node = next(iter(gm.graph.nodes)) 292*523fa7a6SAndroid Build Coastguard Worker 293*523fa7a6SAndroid Build Coastguard Worker def placeholder( # pyre-fixme[14] 294*523fa7a6SAndroid Build Coastguard Worker self, 295*523fa7a6SAndroid Build Coastguard Worker target: str, 296*523fa7a6SAndroid Build Coastguard Worker args: Tuple[Argument, ...], 297*523fa7a6SAndroid Build Coastguard Worker kwargs: Dict[str, Argument], 298*523fa7a6SAndroid Build Coastguard Worker ) -> ProxyValue: 299*523fa7a6SAndroid Build Coastguard Worker arg = super().placeholder(target, args, kwargs) 300*523fa7a6SAndroid Build Coastguard Worker return self.callback.placeholder(target, arg, NodeMetadata(self.node.meta)) 301*523fa7a6SAndroid Build Coastguard Worker 302*523fa7a6SAndroid Build Coastguard Worker def output( 303*523fa7a6SAndroid Build Coastguard Worker self, 304*523fa7a6SAndroid Build Coastguard Worker target: torch.fx.node.Target, 305*523fa7a6SAndroid Build Coastguard Worker args: Tuple[Argument, ...], 306*523fa7a6SAndroid Build Coastguard Worker kwargs: Dict[str, Argument], 307*523fa7a6SAndroid Build Coastguard Worker ) -> ProxyValue: 308*523fa7a6SAndroid Build Coastguard Worker return self.callback.output(args[0], NodeMetadata(self.node.meta)).data 309*523fa7a6SAndroid Build Coastguard Worker 310*523fa7a6SAndroid Build Coastguard Worker def call_function( 311*523fa7a6SAndroid Build Coastguard Worker self, 312*523fa7a6SAndroid Build Coastguard Worker target: torch.fx.node.Target, 313*523fa7a6SAndroid Build Coastguard Worker args: Tuple[Argument, ...], 314*523fa7a6SAndroid Build Coastguard Worker kwargs: Dict[str, Argument], 315*523fa7a6SAndroid Build Coastguard Worker ) -> ProxyValue: 316*523fa7a6SAndroid Build Coastguard Worker meta = NodeMetadata(self.node.meta) 317*523fa7a6SAndroid Build Coastguard Worker 318*523fa7a6SAndroid Build Coastguard Worker if target == operator.getitem: 319*523fa7a6SAndroid Build Coastguard Worker value, key = args 320*523fa7a6SAndroid Build Coastguard Worker return self.callback.call_getitem(value, key, meta) 321*523fa7a6SAndroid Build Coastguard Worker elif getattr(target, "__module__", None) in { 322*523fa7a6SAndroid Build Coastguard Worker "_operator", 323*523fa7a6SAndroid Build Coastguard Worker "builtins", 324*523fa7a6SAndroid Build Coastguard Worker "math", 325*523fa7a6SAndroid Build Coastguard Worker }: 326*523fa7a6SAndroid Build Coastguard Worker assert callable(target) 327*523fa7a6SAndroid Build Coastguard Worker return self.callback.call_sym(target, args, meta) 328*523fa7a6SAndroid Build Coastguard Worker elif target in _TORCH_SYM_OPS: 329*523fa7a6SAndroid Build Coastguard Worker assert callable(target) 330*523fa7a6SAndroid Build Coastguard Worker return self.callback.call_sym(target, args, meta) 331*523fa7a6SAndroid Build Coastguard Worker elif isinstance( 332*523fa7a6SAndroid Build Coastguard Worker target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket) 333*523fa7a6SAndroid Build Coastguard Worker ): 334*523fa7a6SAndroid Build Coastguard Worker return self.callback.call_operator( 335*523fa7a6SAndroid Build Coastguard Worker target, 336*523fa7a6SAndroid Build Coastguard Worker args, 337*523fa7a6SAndroid Build Coastguard Worker kwargs, 338*523fa7a6SAndroid Build Coastguard Worker meta, 339*523fa7a6SAndroid Build Coastguard Worker ) 340*523fa7a6SAndroid Build Coastguard Worker elif target == torch.ops.higher_order.cond: 341*523fa7a6SAndroid Build Coastguard Worker pred, true_fn, false_fn, inputs = args 342*523fa7a6SAndroid Build Coastguard Worker return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta) 343*523fa7a6SAndroid Build Coastguard Worker elif target == torch.ops.higher_order.map_impl: 344*523fa7a6SAndroid Build Coastguard Worker f, mapped_args, operands = args # type: ignore[assignment] 345*523fa7a6SAndroid Build Coastguard Worker return self.callback.call_map(f, mapped_args, operands, meta) 346*523fa7a6SAndroid Build Coastguard Worker # For other unregistered HigherOrderOps, just interpret them blindly 347*523fa7a6SAndroid Build Coastguard Worker elif isinstance(target, torch._ops.HigherOrderOperator): 348*523fa7a6SAndroid Build Coastguard Worker return self.callback._fx( 349*523fa7a6SAndroid Build Coastguard Worker "call_function", 350*523fa7a6SAndroid Build Coastguard Worker target, 351*523fa7a6SAndroid Build Coastguard Worker args, 352*523fa7a6SAndroid Build Coastguard Worker kwargs, 353*523fa7a6SAndroid Build Coastguard Worker meta, 354*523fa7a6SAndroid Build Coastguard Worker ) 355*523fa7a6SAndroid Build Coastguard Worker else: 356*523fa7a6SAndroid Build Coastguard Worker raise ExportPassBaseError(f"Unsupported target type: {target}") 357*523fa7a6SAndroid Build Coastguard Worker 358*523fa7a6SAndroid Build Coastguard Worker def get_attr( # pyre-fixme[14] 359*523fa7a6SAndroid Build Coastguard Worker self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument] 360*523fa7a6SAndroid Build Coastguard Worker ) -> Argument: 361*523fa7a6SAndroid Build Coastguard Worker return super().get_attr(target, args, kwargs) 362*523fa7a6SAndroid Build Coastguard Worker 363*523fa7a6SAndroid Build Coastguard Worker def call_module( 364*523fa7a6SAndroid Build Coastguard Worker self, 365*523fa7a6SAndroid Build Coastguard Worker target: torch.fx.node.Target, 366*523fa7a6SAndroid Build Coastguard Worker args: Tuple[Argument, ...], 367*523fa7a6SAndroid Build Coastguard Worker kwargs: Dict[str, Argument], 368*523fa7a6SAndroid Build Coastguard Worker ) -> None: 369*523fa7a6SAndroid Build Coastguard Worker raise ExportPassBaseError("call_module is not supported.") 370*523fa7a6SAndroid Build Coastguard Worker 371*523fa7a6SAndroid Build Coastguard Worker def call_method( # pyre-fixme[14] 372*523fa7a6SAndroid Build Coastguard Worker self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument] 373*523fa7a6SAndroid Build Coastguard Worker ) -> None: 374*523fa7a6SAndroid Build Coastguard Worker raise ExportPassBaseError("call_method is not supported.") 375*523fa7a6SAndroid Build Coastguard Worker 376*523fa7a6SAndroid Build Coastguard Worker def run_node(self, n: torch.fx.Node) -> Argument: 377*523fa7a6SAndroid Build Coastguard Worker self.node = n 378*523fa7a6SAndroid Build Coastguard Worker self.callback.node_debug_str = n.format_node() 379*523fa7a6SAndroid Build Coastguard Worker return super().run_node(n) 380*523fa7a6SAndroid Build Coastguard Worker 381*523fa7a6SAndroid Build Coastguard Worker def __init__(self) -> None: 382*523fa7a6SAndroid Build Coastguard Worker self.interpreter = torch.fx.Interpreter( 383*523fa7a6SAndroid Build Coastguard Worker torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) 384*523fa7a6SAndroid Build Coastguard Worker ) 385*523fa7a6SAndroid Build Coastguard Worker self.tracer = self.ExportTracer(self, CodeGen()) # pyre-ignore 386*523fa7a6SAndroid Build Coastguard Worker self.fake_tensor_mode: Optional[FakeTensorMode] = None 387*523fa7a6SAndroid Build Coastguard Worker self._initialized = True 388*523fa7a6SAndroid Build Coastguard Worker self.node_debug_str: Optional[str] = None 389*523fa7a6SAndroid Build Coastguard Worker 390*523fa7a6SAndroid Build Coastguard Worker def _fx( 391*523fa7a6SAndroid Build Coastguard Worker self, 392*523fa7a6SAndroid Build Coastguard Worker kind: str, 393*523fa7a6SAndroid Build Coastguard Worker target: torch.fx.node.Target, 394*523fa7a6SAndroid Build Coastguard Worker args: Tuple[Argument, ...], 395*523fa7a6SAndroid Build Coastguard Worker kwargs: Dict[str, Argument], 396*523fa7a6SAndroid Build Coastguard Worker meta: NodeMetadata, 397*523fa7a6SAndroid Build Coastguard Worker ) -> ProxyValue: 398*523fa7a6SAndroid Build Coastguard Worker args_data, kwargs_data = pytree.tree_map_only( 399*523fa7a6SAndroid Build Coastguard Worker ProxyValue, lambda x: x.data, (args, kwargs) 400*523fa7a6SAndroid Build Coastguard Worker ) 401*523fa7a6SAndroid Build Coastguard Worker res_data = getattr(self.interpreter, kind)(target, args_data, kwargs_data) 402*523fa7a6SAndroid Build Coastguard Worker args_proxy, kwargs_proxy = pytree.tree_map_only( 403*523fa7a6SAndroid Build Coastguard Worker ProxyValue, lambda x: x.proxy, (args, kwargs) 404*523fa7a6SAndroid Build Coastguard Worker ) 405*523fa7a6SAndroid Build Coastguard Worker 406*523fa7a6SAndroid Build Coastguard Worker name = None 407*523fa7a6SAndroid Build Coastguard Worker if isinstance(target, torch._ops.OpOverload): 408*523fa7a6SAndroid Build Coastguard Worker name = self.tracer.graph._target_to_str(target.overloadpacket.__name__) 409*523fa7a6SAndroid Build Coastguard Worker 410*523fa7a6SAndroid Build Coastguard Worker res_proxy = self.tracer.create_proxy( 411*523fa7a6SAndroid Build Coastguard Worker kind, target, args_proxy, kwargs_proxy, name=name 412*523fa7a6SAndroid Build Coastguard Worker ) 413*523fa7a6SAndroid Build Coastguard Worker res_proxy.node.meta.update(meta.data) 414*523fa7a6SAndroid Build Coastguard Worker self.tracer.set_metadata(res_proxy.node, res_data) 415*523fa7a6SAndroid Build Coastguard Worker return ProxyValue(res_data, res_proxy) 416*523fa7a6SAndroid Build Coastguard Worker 417*523fa7a6SAndroid Build Coastguard Worker def inputs(self, graph_module: torch.fx.GraphModule) -> List[Argument]: 418*523fa7a6SAndroid Build Coastguard Worker # TODO(angelayi): Update this with what we decide to do for metadata in 419*523fa7a6SAndroid Build Coastguard Worker # the exported graph module 420*523fa7a6SAndroid Build Coastguard Worker if (args := graph_module.meta.get("args", None)) is not None: 421*523fa7a6SAndroid Build Coastguard Worker return list(args) 422*523fa7a6SAndroid Build Coastguard Worker 423*523fa7a6SAndroid Build Coastguard Worker def extract_input(node: torch.fx.Node) -> Optional[FakeTensor]: 424*523fa7a6SAndroid Build Coastguard Worker if "val" in node.meta: 425*523fa7a6SAndroid Build Coastguard Worker fake = node.meta["val"] 426*523fa7a6SAndroid Build Coastguard Worker if hasattr(fake, "constant") and fake.constant is not None: 427*523fa7a6SAndroid Build Coastguard Worker return fake.constant 428*523fa7a6SAndroid Build Coastguard Worker return fake 429*523fa7a6SAndroid Build Coastguard Worker elif tensor_meta := node.meta.get("tensor_meta"): 430*523fa7a6SAndroid Build Coastguard Worker assert self.fake_tensor_mode is not None 431*523fa7a6SAndroid Build Coastguard Worker return FakeTensor( 432*523fa7a6SAndroid Build Coastguard Worker self.fake_tensor_mode, 433*523fa7a6SAndroid Build Coastguard Worker torch.empty( 434*523fa7a6SAndroid Build Coastguard Worker tensor_meta.shape, 435*523fa7a6SAndroid Build Coastguard Worker dtype=tensor_meta.dtype, 436*523fa7a6SAndroid Build Coastguard Worker device="meta", 437*523fa7a6SAndroid Build Coastguard Worker requires_grad=tensor_meta.requires_grad, 438*523fa7a6SAndroid Build Coastguard Worker memory_format=tensor_meta.memory_format, 439*523fa7a6SAndroid Build Coastguard Worker ), 440*523fa7a6SAndroid Build Coastguard Worker torch.device("cpu"), 441*523fa7a6SAndroid Build Coastguard Worker ) 442*523fa7a6SAndroid Build Coastguard Worker elif len(node.users) == 0: 443*523fa7a6SAndroid Build Coastguard Worker return None 444*523fa7a6SAndroid Build Coastguard Worker raise ExportPassBaseError( 445*523fa7a6SAndroid Build Coastguard Worker f"Cannot construct an input for graph module: {graph_module}.", 446*523fa7a6SAndroid Build Coastguard Worker ) 447*523fa7a6SAndroid Build Coastguard Worker 448*523fa7a6SAndroid Build Coastguard Worker return [ 449*523fa7a6SAndroid Build Coastguard Worker extract_input(node) 450*523fa7a6SAndroid Build Coastguard Worker for node in graph_module.graph.nodes 451*523fa7a6SAndroid Build Coastguard Worker if node.op == "placeholder" 452*523fa7a6SAndroid Build Coastguard Worker ] 453*523fa7a6SAndroid Build Coastguard Worker 454*523fa7a6SAndroid Build Coastguard Worker def on_attr(self, attr: ProxyValue) -> None: 455*523fa7a6SAndroid Build Coastguard Worker pass 456*523fa7a6SAndroid Build Coastguard Worker 457*523fa7a6SAndroid Build Coastguard Worker def placeholder(self, name: str, arg: Argument, meta: NodeMetadata) -> ProxyValue: 458*523fa7a6SAndroid Build Coastguard Worker arg_proxy = self.tracer.create_proxy("placeholder", name, (), {}) 459*523fa7a6SAndroid Build Coastguard Worker arg_proxy.node.meta = meta.data 460*523fa7a6SAndroid Build Coastguard Worker arg_proxy.node.meta["val"] = arg 461*523fa7a6SAndroid Build Coastguard Worker return ProxyValue(arg, arg_proxy) 462*523fa7a6SAndroid Build Coastguard Worker 463*523fa7a6SAndroid Build Coastguard Worker def call_operator( 464*523fa7a6SAndroid Build Coastguard Worker self, 465*523fa7a6SAndroid Build Coastguard Worker op, # pyre-ignore 466*523fa7a6SAndroid Build Coastguard Worker args: Tuple[Argument, ...], 467*523fa7a6SAndroid Build Coastguard Worker kwargs: Dict[str, Argument], 468*523fa7a6SAndroid Build Coastguard Worker meta: NodeMetadata, 469*523fa7a6SAndroid Build Coastguard Worker ) -> ProxyValue: 470*523fa7a6SAndroid Build Coastguard Worker return self._fx("call_function", op, args, kwargs, meta) 471*523fa7a6SAndroid Build Coastguard Worker 472*523fa7a6SAndroid Build Coastguard Worker def call_sym( 473*523fa7a6SAndroid Build Coastguard Worker self, 474*523fa7a6SAndroid Build Coastguard Worker target: Fn, 475*523fa7a6SAndroid Build Coastguard Worker args: Tuple[Argument, ...], 476*523fa7a6SAndroid Build Coastguard Worker meta: NodeMetadata, 477*523fa7a6SAndroid Build Coastguard Worker ) -> ProxyValue: 478*523fa7a6SAndroid Build Coastguard Worker return self._fx("call_function", target, args, {}, meta) 479*523fa7a6SAndroid Build Coastguard Worker 480*523fa7a6SAndroid Build Coastguard Worker def call_cond( 481*523fa7a6SAndroid Build Coastguard Worker self, 482*523fa7a6SAndroid Build Coastguard Worker pred: ProxyValue, 483*523fa7a6SAndroid Build Coastguard Worker true_fn: torch.fx.GraphModule, 484*523fa7a6SAndroid Build Coastguard Worker false_fn: torch.fx.GraphModule, 485*523fa7a6SAndroid Build Coastguard Worker inputs: List[Argument], 486*523fa7a6SAndroid Build Coastguard Worker meta: NodeMetadata, 487*523fa7a6SAndroid Build Coastguard Worker ) -> ProxyValue: 488*523fa7a6SAndroid Build Coastguard Worker true_branch = self.call_submodule(true_fn, tuple(inputs)) 489*523fa7a6SAndroid Build Coastguard Worker false_branch = self.call_submodule(false_fn, tuple(inputs)) 490*523fa7a6SAndroid Build Coastguard Worker assert true_branch is not None 491*523fa7a6SAndroid Build Coastguard Worker assert false_branch is not None 492*523fa7a6SAndroid Build Coastguard Worker return self._fx( 493*523fa7a6SAndroid Build Coastguard Worker "call_function", 494*523fa7a6SAndroid Build Coastguard Worker torch.ops.higher_order.cond, 495*523fa7a6SAndroid Build Coastguard Worker (pred, true_branch.graph_module, false_branch.graph_module, list(inputs)), 496*523fa7a6SAndroid Build Coastguard Worker {}, 497*523fa7a6SAndroid Build Coastguard Worker meta, 498*523fa7a6SAndroid Build Coastguard Worker ) 499*523fa7a6SAndroid Build Coastguard Worker 500*523fa7a6SAndroid Build Coastguard Worker def call_map( 501*523fa7a6SAndroid Build Coastguard Worker self, 502*523fa7a6SAndroid Build Coastguard Worker f: torch.fx.GraphModule, 503*523fa7a6SAndroid Build Coastguard Worker mapped_args: List[ProxyValue], 504*523fa7a6SAndroid Build Coastguard Worker operands: List[ProxyValue], 505*523fa7a6SAndroid Build Coastguard Worker meta: NodeMetadata, 506*523fa7a6SAndroid Build Coastguard Worker ) -> ProxyValue: 507*523fa7a6SAndroid Build Coastguard Worker xs = _unstack_pytree([arg.data for arg in mapped_args])[0] 508*523fa7a6SAndroid Build Coastguard Worker f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in operands])) 509*523fa7a6SAndroid Build Coastguard Worker assert f_branch is not None 510*523fa7a6SAndroid Build Coastguard Worker return self._fx( 511*523fa7a6SAndroid Build Coastguard Worker "call_function", 512*523fa7a6SAndroid Build Coastguard Worker torch.ops.higher_order.map_impl, 513*523fa7a6SAndroid Build Coastguard Worker (f_branch.graph_module, mapped_args, operands), 514*523fa7a6SAndroid Build Coastguard Worker {}, 515*523fa7a6SAndroid Build Coastguard Worker meta, 516*523fa7a6SAndroid Build Coastguard Worker ) 517*523fa7a6SAndroid Build Coastguard Worker 518*523fa7a6SAndroid Build Coastguard Worker def call_getitem( 519*523fa7a6SAndroid Build Coastguard Worker self, value: ProxyValue, key: int, meta: NodeMetadata 520*523fa7a6SAndroid Build Coastguard Worker ) -> ProxyValue: 521*523fa7a6SAndroid Build Coastguard Worker return self._fx("call_function", operator.getitem, (value, key), {}, meta) 522*523fa7a6SAndroid Build Coastguard Worker 523*523fa7a6SAndroid Build Coastguard Worker def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue: 524*523fa7a6SAndroid Build Coastguard Worker return self._fx("output", "output", (results,), {}, meta) 525*523fa7a6SAndroid Build Coastguard Worker 526*523fa7a6SAndroid Build Coastguard Worker def call_submodule( 527*523fa7a6SAndroid Build Coastguard Worker self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...] 528*523fa7a6SAndroid Build Coastguard Worker ) -> PassResult: 529*523fa7a6SAndroid Build Coastguard Worker prev_tracer, self.tracer = self.tracer, self.ExportTracer( 530*523fa7a6SAndroid Build Coastguard Worker self, graph_module.graph._codegen 531*523fa7a6SAndroid Build Coastguard Worker ) 532*523fa7a6SAndroid Build Coastguard Worker self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode 533*523fa7a6SAndroid Build Coastguard Worker interpreter = self.ExportInterpreter(self, graph_module) 534*523fa7a6SAndroid Build Coastguard Worker prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter( 535*523fa7a6SAndroid Build Coastguard Worker torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) 536*523fa7a6SAndroid Build Coastguard Worker ) 537*523fa7a6SAndroid Build Coastguard Worker inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs) 538*523fa7a6SAndroid Build Coastguard Worker with fx_traceback.preserve_node_meta(): 539*523fa7a6SAndroid Build Coastguard Worker interpreter.run(*inputs_data) 540*523fa7a6SAndroid Build Coastguard Worker 541*523fa7a6SAndroid Build Coastguard Worker new_graph_module = torch.fx.GraphModule(self.tracer.root, self.tracer.graph) 542*523fa7a6SAndroid Build Coastguard Worker 543*523fa7a6SAndroid Build Coastguard Worker self.tracer = prev_tracer 544*523fa7a6SAndroid Build Coastguard Worker self.interpreter = prev_interpreter 545*523fa7a6SAndroid Build Coastguard Worker return PassResult( 546*523fa7a6SAndroid Build Coastguard Worker new_graph_module, 547*523fa7a6SAndroid Build Coastguard Worker True, 548*523fa7a6SAndroid Build Coastguard Worker ) 549*523fa7a6SAndroid Build Coastguard Worker 550*523fa7a6SAndroid Build Coastguard Worker def call(self, graph_module: fx.GraphModule) -> PassResult: 551*523fa7a6SAndroid Build Coastguard Worker if not getattr(self, "_initialized", False): 552*523fa7a6SAndroid Build Coastguard Worker raise ExportPassBaseError( 553*523fa7a6SAndroid Build Coastguard Worker "ExportPass is not initialized with __init__().", 554*523fa7a6SAndroid Build Coastguard Worker ) 555*523fa7a6SAndroid Build Coastguard Worker 556*523fa7a6SAndroid Build Coastguard Worker inputs = self.inputs(graph_module) 557*523fa7a6SAndroid Build Coastguard Worker 558*523fa7a6SAndroid Build Coastguard Worker fake_tensor_mode = None 559*523fa7a6SAndroid Build Coastguard Worker for i in inputs: 560*523fa7a6SAndroid Build Coastguard Worker if isinstance(i, FakeTensor): 561*523fa7a6SAndroid Build Coastguard Worker assert ( 562*523fa7a6SAndroid Build Coastguard Worker fake_tensor_mode is None or fake_tensor_mode is i.fake_mode 563*523fa7a6SAndroid Build Coastguard Worker ), "Multiple fake tensor mode detected." 564*523fa7a6SAndroid Build Coastguard Worker fake_tensor_mode = i.fake_mode 565*523fa7a6SAndroid Build Coastguard Worker if fake_tensor_mode is None: 566*523fa7a6SAndroid Build Coastguard Worker self.tracer.fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True) 567*523fa7a6SAndroid Build Coastguard Worker fake_tensor_mode = nullcontext() # type: ignore[assignment] 568*523fa7a6SAndroid Build Coastguard Worker dispatcher_mode = nullcontext() # type: ignore[assignment] 569*523fa7a6SAndroid Build Coastguard Worker else: 570*523fa7a6SAndroid Build Coastguard Worker fake_tensor_mode.allow_non_fake_inputs = True 571*523fa7a6SAndroid Build Coastguard Worker self.tracer.fake_tensor_mode = fake_tensor_mode 572*523fa7a6SAndroid Build Coastguard Worker dispatcher_mode = enable_python_dispatcher() # type: ignore[assignment] 573*523fa7a6SAndroid Build Coastguard Worker self.fake_tensor_mode = self.tracer.fake_tensor_mode 574*523fa7a6SAndroid Build Coastguard Worker 575*523fa7a6SAndroid Build Coastguard Worker with fake_tensor_mode, dispatcher_mode: # type: ignore[assignment, union-attr] 576*523fa7a6SAndroid Build Coastguard Worker result = self.call_submodule(graph_module, tuple(inputs)) 577*523fa7a6SAndroid Build Coastguard Worker 578*523fa7a6SAndroid Build Coastguard Worker return result 579*523fa7a6SAndroid Build Coastguard Worker 580*523fa7a6SAndroid Build Coastguard Worker 581*523fa7a6SAndroid Build Coastguard Workerclass ExportPass(_ExportPassBase): 582*523fa7a6SAndroid Build Coastguard Worker class ExportTracer(_ExportPassBase.ExportTracer): 583*523fa7a6SAndroid Build Coastguard Worker def create_arg(self, a: Argument) -> torch.fx.Node: 584*523fa7a6SAndroid Build Coastguard Worker if isinstance(a, torch.nn.Module): 585*523fa7a6SAndroid Build Coastguard Worker if a not in self.submodules: 586*523fa7a6SAndroid Build Coastguard Worker prefix = "lowered_module" if is_lowered_module(a) else "submodule" 587*523fa7a6SAndroid Build Coastguard Worker name_submodule = f"{prefix}_{len(self.submodules)}" 588*523fa7a6SAndroid Build Coastguard Worker self.root.add_module(name_submodule, a) 589*523fa7a6SAndroid Build Coastguard Worker self.submodules[a] = name_submodule 590*523fa7a6SAndroid Build Coastguard Worker return super().create_arg(a) 591*523fa7a6SAndroid Build Coastguard Worker 592*523fa7a6SAndroid Build Coastguard Worker class ExportInterpreter(_ExportPassBase.ExportInterpreter): 593*523fa7a6SAndroid Build Coastguard Worker """ 594*523fa7a6SAndroid Build Coastguard Worker Interpreter to callback on any ExportPassBase functions 595*523fa7a6SAndroid Build Coastguard Worker """ 596*523fa7a6SAndroid Build Coastguard Worker 597*523fa7a6SAndroid Build Coastguard Worker def __init__(self, callback: "ExportPass", gm: fx.GraphModule) -> None: 598*523fa7a6SAndroid Build Coastguard Worker super().__init__(callback, gm) 599*523fa7a6SAndroid Build Coastguard Worker 600*523fa7a6SAndroid Build Coastguard Worker def call_function( 601*523fa7a6SAndroid Build Coastguard Worker self, 602*523fa7a6SAndroid Build Coastguard Worker target: torch.fx.node.Target, 603*523fa7a6SAndroid Build Coastguard Worker args: Tuple[Argument, ...], 604*523fa7a6SAndroid Build Coastguard Worker kwargs: Dict[str, Argument], 605*523fa7a6SAndroid Build Coastguard Worker ) -> ProxyValue: 606*523fa7a6SAndroid Build Coastguard Worker meta = NodeMetadata(self.node.meta) 607*523fa7a6SAndroid Build Coastguard Worker if target == operator.getitem: 608*523fa7a6SAndroid Build Coastguard Worker value, key = args 609*523fa7a6SAndroid Build Coastguard Worker return self.callback.call_getitem(value, key, meta) 610*523fa7a6SAndroid Build Coastguard Worker elif isinstance(target, EdgeOpOverload): 611*523fa7a6SAndroid Build Coastguard Worker return self.callback.call_operator( 612*523fa7a6SAndroid Build Coastguard Worker target, 613*523fa7a6SAndroid Build Coastguard Worker args, 614*523fa7a6SAndroid Build Coastguard Worker kwargs, 615*523fa7a6SAndroid Build Coastguard Worker meta, 616*523fa7a6SAndroid Build Coastguard Worker ) 617*523fa7a6SAndroid Build Coastguard Worker 618*523fa7a6SAndroid Build Coastguard Worker # TODO according to zhengxu ExportPassBase should not be aware of 619*523fa7a6SAndroid Build Coastguard Worker # memory.alloc. Check this comment: 620*523fa7a6SAndroid Build Coastguard Worker # https://www.internalfb.com/diff/D42758019?dst_version_fbid=5906016402813292&transaction_fbid=1104713900200176 621*523fa7a6SAndroid Build Coastguard Worker elif target == memory.alloc: 622*523fa7a6SAndroid Build Coastguard Worker return self.callback._fx( 623*523fa7a6SAndroid Build Coastguard Worker "call_function", 624*523fa7a6SAndroid Build Coastguard Worker target, 625*523fa7a6SAndroid Build Coastguard Worker args, 626*523fa7a6SAndroid Build Coastguard Worker kwargs, 627*523fa7a6SAndroid Build Coastguard Worker meta, 628*523fa7a6SAndroid Build Coastguard Worker ) 629*523fa7a6SAndroid Build Coastguard Worker 630*523fa7a6SAndroid Build Coastguard Worker elif target == executorch_call_delegate: 631*523fa7a6SAndroid Build Coastguard Worker lowered_module = args[0] 632*523fa7a6SAndroid Build Coastguard Worker args = args[1:] 633*523fa7a6SAndroid Build Coastguard Worker return self.callback.call_delegate( # pyre-ignore 634*523fa7a6SAndroid Build Coastguard Worker lowered_module, 635*523fa7a6SAndroid Build Coastguard Worker args, 636*523fa7a6SAndroid Build Coastguard Worker kwargs, 637*523fa7a6SAndroid Build Coastguard Worker NodeMetadata(self.node.meta), 638*523fa7a6SAndroid Build Coastguard Worker ) 639*523fa7a6SAndroid Build Coastguard Worker 640*523fa7a6SAndroid Build Coastguard Worker return super().call_function(target, args, kwargs) 641*523fa7a6SAndroid Build Coastguard Worker 642*523fa7a6SAndroid Build Coastguard Worker def call_delegate( 643*523fa7a6SAndroid Build Coastguard Worker self, 644*523fa7a6SAndroid Build Coastguard Worker # pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type. 645*523fa7a6SAndroid Build Coastguard Worker lowered_module: "LoweredBackendModule", # noqa 646*523fa7a6SAndroid Build Coastguard Worker args: Tuple[ProxyValue, ...], 647*523fa7a6SAndroid Build Coastguard Worker kwargs: Dict[str, Argument], 648*523fa7a6SAndroid Build Coastguard Worker meta: NodeMetadata, 649*523fa7a6SAndroid Build Coastguard Worker ) -> ProxyValue: 650*523fa7a6SAndroid Build Coastguard Worker args = (lowered_module,) + args 651*523fa7a6SAndroid Build Coastguard Worker return self._fx( 652*523fa7a6SAndroid Build Coastguard Worker "call_function", 653*523fa7a6SAndroid Build Coastguard Worker executorch_call_delegate, 654*523fa7a6SAndroid Build Coastguard Worker args, 655*523fa7a6SAndroid Build Coastguard Worker kwargs, 656*523fa7a6SAndroid Build Coastguard Worker meta, 657*523fa7a6SAndroid Build Coastguard Worker ) 658*523fa7a6SAndroid Build Coastguard Worker 659*523fa7a6SAndroid Build Coastguard Worker def call_submodule( 660*523fa7a6SAndroid Build Coastguard Worker self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...] 661*523fa7a6SAndroid Build Coastguard Worker ) -> PassResult: 662*523fa7a6SAndroid Build Coastguard Worker res = super().call_submodule(graph_module, inputs) 663*523fa7a6SAndroid Build Coastguard Worker 664*523fa7a6SAndroid Build Coastguard Worker def preserve_original_ph_meta_val( 665*523fa7a6SAndroid Build Coastguard Worker gm: torch.fx.GraphModule, new_gm: torch.fx.GraphModule 666*523fa7a6SAndroid Build Coastguard Worker ) -> None: 667*523fa7a6SAndroid Build Coastguard Worker def get_phs(gm: torch.fx.GraphModule) -> List[torch.fx.Node]: 668*523fa7a6SAndroid Build Coastguard Worker return [node for node in gm.graph.nodes if node.op == "placeholder"] 669*523fa7a6SAndroid Build Coastguard Worker 670*523fa7a6SAndroid Build Coastguard Worker def migrate_meta_val( 671*523fa7a6SAndroid Build Coastguard Worker orig_phs: List[torch.fx.Node], new_phs: List[torch.fx.Node] 672*523fa7a6SAndroid Build Coastguard Worker ) -> None: 673*523fa7a6SAndroid Build Coastguard Worker if len(orig_phs) != len(new_phs): 674*523fa7a6SAndroid Build Coastguard Worker raise ExportError( 675*523fa7a6SAndroid Build Coastguard Worker ExportErrorType.NOT_SUPPORTED, 676*523fa7a6SAndroid Build Coastguard Worker "ExportPassBase doesn't support changing the placeholders", 677*523fa7a6SAndroid Build Coastguard Worker ) 678*523fa7a6SAndroid Build Coastguard Worker for ph, new_ph in zip(orig_phs, new_phs): 679*523fa7a6SAndroid Build Coastguard Worker if isinstance(new_ph.meta["val"], torch.Tensor): 680*523fa7a6SAndroid Build Coastguard Worker if ( 681*523fa7a6SAndroid Build Coastguard Worker not isinstance(ph.meta["val"], torch.Tensor) 682*523fa7a6SAndroid Build Coastguard Worker or new_ph.meta["val"].size() != ph.meta["val"].size() 683*523fa7a6SAndroid Build Coastguard Worker ): 684*523fa7a6SAndroid Build Coastguard Worker raise ExportError( 685*523fa7a6SAndroid Build Coastguard Worker ExportErrorType.NOT_SUPPORTED, 686*523fa7a6SAndroid Build Coastguard Worker "ExportPassBase doesn't support changing the placeholders", 687*523fa7a6SAndroid Build Coastguard Worker ) 688*523fa7a6SAndroid Build Coastguard Worker new_ph.meta["val"] = ph.meta["val"] 689*523fa7a6SAndroid Build Coastguard Worker 690*523fa7a6SAndroid Build Coastguard Worker migrate_meta_val(get_phs(gm), get_phs(new_gm)) 691*523fa7a6SAndroid Build Coastguard Worker 692*523fa7a6SAndroid Build Coastguard Worker # After one pass, new_graph_module's placeholders will always hold fake tensors in 693*523fa7a6SAndroid Build Coastguard Worker # meta['val'] but sometimes we want to preserve the original meta['val'] of placeholders 694*523fa7a6SAndroid Build Coastguard Worker # 695*523fa7a6SAndroid Build Coastguard Worker # For example, custom flows and certain passes assume no fake_tensor_mode is activated 696*523fa7a6SAndroid Build Coastguard Worker # and it doesn't quite work with fake_tensor_mode. but we don't bother to fix them. 697*523fa7a6SAndroid Build Coastguard Worker # So we'll just reset the meta of placeholders to its original value. It's safe because that 698*523fa7a6SAndroid Build Coastguard Worker # 1. For models captured with pt2_mode, the meta['val'] of placeholders are fake_tensors already, so 699*523fa7a6SAndroid Build Coastguard Worker # preserving it to the new graph module won't hurt. 700*523fa7a6SAndroid Build Coastguard Worker # 2. For models captured with dispatch_trace, the meta['val'] field 701*523fa7a6SAndroid Build Coastguard Worker # Note that it's only safe when passes don't modify the inputs. 702*523fa7a6SAndroid Build Coastguard Worker preserve_original_ph_meta_val(graph_module, res.graph_module) 703*523fa7a6SAndroid Build Coastguard Worker 704*523fa7a6SAndroid Build Coastguard Worker return res 705*523fa7a6SAndroid Build Coastguard Worker 706*523fa7a6SAndroid Build Coastguard Worker 707*523fa7a6SAndroid Build Coastguard Worker@runtime_checkable 708*523fa7a6SAndroid Build Coastguard Workerclass ArgSchema(Protocol): 709*523fa7a6SAndroid Build Coastguard Worker name: str 710*523fa7a6SAndroid Build Coastguard Worker kwarg_only: bool 711*523fa7a6SAndroid Build Coastguard Worker type: Any # pyre-ignore 712*523fa7a6SAndroid Build Coastguard Worker 713*523fa7a6SAndroid Build Coastguard Worker 714*523fa7a6SAndroid Build Coastguard Workerdef map_args( 715*523fa7a6SAndroid Build Coastguard Worker op: torch._ops.OpOverload, 716*523fa7a6SAndroid Build Coastguard Worker fn: Fn, 717*523fa7a6SAndroid Build Coastguard Worker args: Argument, 718*523fa7a6SAndroid Build Coastguard Worker kwargs: Dict[str, Argument], 719*523fa7a6SAndroid Build Coastguard Worker) -> Tuple[Argument, Dict[str, Argument]]: 720*523fa7a6SAndroid Build Coastguard Worker assert isinstance(args, tuple) 721*523fa7a6SAndroid Build Coastguard Worker assert isinstance(kwargs, dict) 722*523fa7a6SAndroid Build Coastguard Worker args = list(args) 723*523fa7a6SAndroid Build Coastguard Worker kwargs = kwargs.copy() 724*523fa7a6SAndroid Build Coastguard Worker 725*523fa7a6SAndroid Build Coastguard Worker def update(key: K, args: MutableMapping[K, PyTree], schema: ArgSchema) -> None: 726*523fa7a6SAndroid Build Coastguard Worker args[key] = fn(args[key], schema) 727*523fa7a6SAndroid Build Coastguard Worker 728*523fa7a6SAndroid Build Coastguard Worker for i, schema in enumerate(op._schema.arguments): 729*523fa7a6SAndroid Build Coastguard Worker if schema.name in kwargs: 730*523fa7a6SAndroid Build Coastguard Worker update(schema.name, kwargs, schema) 731*523fa7a6SAndroid Build Coastguard Worker elif not schema.kwarg_only and i < len(args): 732*523fa7a6SAndroid Build Coastguard Worker update(i, args, schema) # pyre-ignore 733*523fa7a6SAndroid Build Coastguard Worker 734*523fa7a6SAndroid Build Coastguard Worker return tuple(args), kwargs 735