# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import warnings
from collections import namedtuple
from contextlib import contextmanager
from types import MethodType
from typing import Any, Callable, cast, List, Optional, Tuple

import torch
from executorch.exir.capture._config import CaptureConfig
from executorch.exir.error import ExportError, ExportErrorType, InternalError
from executorch.exir.program import ExirExportedProgram
from executorch.exir.program._program import _transform, HackedUpExportedProgramDONOTUSE
from executorch.exir.tracer import (
    _default_decomposition_table,
    dispatch_trace,
    dynamo_trace,
    flatten_output,
    Value,
)
from executorch.exir.verification.verifier import EXIRATenDialectVerifierBase
from torch import _guards
from torch._dispatch.python import enable_python_dispatcher
from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.export import export
from torch.export.exported_program import (
    ExportedProgram,
    ExportGraphSignature,
    InputKind,
    InputSpec,
    ModuleCallEntry,
    ModuleCallSignature,
    OutputKind,
    OutputSpec,
    TensorArgument,
)
from torch.func import functionalize
from torch.fx._compatibility import compatibility
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.utils import _pytree as pytree


Val = Any


# Record with fields (method_name, callable, args, dynamic_shapes). It is not
# used in this module; presumably consumed by callers that capture several
# methods -- TODO confirm.
CompileSpec = namedtuple(
    "CompileSpec", ["method_name", "callable", "args", "dynamic_shapes"]
)


# Pair of pytree input/output specs; passed as ``call_spec`` to the legacy
# program built by ``_capture_legacy_do_not_use``.
CallSpec = namedtuple("CallSpec", ["in_spec", "out_spec"])


@compatibility(is_backward_compatible=False)
def _capture_legacy_do_not_use(f, args) -> ExirExportedProgram:
    """
    This is a legacy API that should be avoided. Prefer ``torch.export`` and
    ``exir.to_edge`` instead.

    Traces ``f`` on ``args`` with ``dispatch_trace`` and wraps the graph in a
    ``HackedUpExportedProgramDONOTUSE`` whose signature treats every
    placeholder as a user input and every graph output as a user output.

    Args:
        f: The callable to trace.
        args: Example inputs for ``f``. They are flattened with pytree to seed
            missing placeholder metadata.

    Returns:
        An ``ExirExportedProgram`` wrapping the traced program.

    Raises:
        ExportError: If the number of graph placeholders differs from the
            number of flattened inputs.
    """
    warnings.warn(
        "This function is now deprecated, please use `torch.export and exir.to_edge` instead. "
        "See https://github.com/pytorch/functorch for more details.",
        DeprecationWarning,
        # stacklevel=2 attributes the warning to the caller, not this module.
        stacklevel=2,
    )

    graph_module = dispatch_trace(f, args)
    flat_args = tuple(pytree.tree_flatten(args)[0])
    in_spec, out_spec = graph_module.in_spec, graph_module.out_spec

    _instantiate_missing_placeholder_val_with_real_inputs(graph_module, flat_args)
    graph_module._apply(torch.Tensor.contiguous)

    user_inputs = [
        node.name for node in graph_module.graph.nodes if node.op == "placeholder"
    ]
    output_node = list(graph_module.graph.nodes)[-1]
    assert output_node.op == "output"
    # NOTE(review): assumes the output node's first arg is a flat sequence of
    # fx.Nodes -- confirm dispatch_trace guarantees this.
    user_outputs = [arg.name for arg in output_node.args[0]]

    # Best effort: fill in missing "val" metadata on call_function nodes by
    # running each target on its inputs' "val"s. Any failure records None.
    for n in graph_module.graph.nodes:
        if n.op == "call_function" and "val" not in n.meta:
            try:
                # Use dedicated names so the ``args`` parameter isn't shadowed.
                node_args, node_kwargs = pytree.tree_map_only(
                    torch.fx.Node, lambda x: x.meta["val"], (n.args, n.kwargs)
                )
                n.meta["val"] = n.target(*node_args, **node_kwargs)
            except Exception:
                n.meta["val"] = None

    ep = HackedUpExportedProgramDONOTUSE(
        root=graph_module,
        graph=graph_module.graph,
        graph_signature=ExportGraphSignature(
            input_specs=[
                InputSpec(
                    kind=InputKind.USER_INPUT, arg=TensorArgument(name=i), target=None
                )
                for i in user_inputs
            ],
            output_specs=[
                OutputSpec(
                    kind=OutputKind.USER_OUTPUT, arg=TensorArgument(name=o), target=None
                )
                for o in user_outputs
            ],
        ),
        call_spec=CallSpec(in_spec, out_spec),
        state_dict={},
        range_constraints={},
        module_call_graph=[
            ModuleCallEntry(
                fqn="",
                signature=ModuleCallSignature(
                    inputs=[],
                    outputs=[],
                    # pyre-fixme[6]: For 3rd argument expected `TreeSpec` but got
                    #  `Union[Tensor, Module]`.
                    in_spec=in_spec,
                    # pyre-fixme[6]: For 4th argument expected `TreeSpec` but got
                    #  `Union[Tensor, Module]`.
                    out_spec=out_spec,
                ),
            )
        ],
        example_inputs=None,
        # NOTE(review): this passes `verifier=` while `capture` passes
        # `verifiers=[...]` to ExportedProgram; presumably this matches
        # HackedUpExportedProgramDONOTUSE's signature -- confirm.
        verifier=EXIRATenDialectVerifierBase,
    )
    return ExirExportedProgram(ep, False)


@contextmanager
def patch_forward(obj: torch.nn.Module, new_method):
    """Helper method to make it easier to cleanly torch.export() a method on a
    module that is not `forward`.

    While the context is active, ``obj.forward`` is ``new_method`` bound to
    ``obj``. On exit, even if the body raises, ``obj`` is restored to its
    prior state.

    TODO(suo): upstream this to torch.export.wrapper.
    """
    # Remember whether ``forward`` was already an instance attribute. If it
    # was not, deleting the patched attribute re-exposes the class's
    # ``forward``. Assigning the saved bound method back would instead leave
    # a self-referencing copy in the instance ``__dict__``.
    had_instance_forward = "forward" in vars(obj)
    original_method = obj.forward

    # Patch the method
    obj.forward = new_method.__get__(obj, obj.__class__)

    try:
        yield
    finally:
        # Restore the original method
        if had_instance_forward:
            obj.forward = original_method
        else:
            vars(obj).pop("forward", None)


class WrapperModule(torch.nn.Module):
    """Adapt a plain callable into an ``nn.Module`` so it can be passed to
    ``torch.export.export``. The module's ``forward`` is ``f`` itself."""

    def __init__(self, f):
        super().__init__()
        # Stored as an instance attribute, so ``f`` is called as-is (it is not
        # bound to this module).
        self.forward = f


@compatibility(is_backward_compatible=False)
def capture(  # noqa: C901
    f: Callable[..., Any],
    args: Tuple[Value, ...],
    config: Optional[CaptureConfig] = None,
    dynamic_shapes: Optional[List[Any]] = None,
) -> ExirExportedProgram:
    """Capture ``f`` into an ``ExirExportedProgram`` (deprecated).

    Prefer ``torch.export`` followed by ``exir.to_edge``. Only
    ``config.pt2_mode`` is supported. Within it:

    * ``enable_aot``: export with ``torch.export.export``, run the default
      decompositions, and replace view ops with view-copy ops. When ``f`` is
      a bound method of an ``nn.Module``, that method is patched in as
      ``forward``. Unless ``config._unlift`` is set, the result is returned
      directly.
    * ``enable_dynamic_shape``: trace with ``dynamo_trace`` in symbolic mode.
    * otherwise: trace with ``dynamo_trace`` in fake mode.

    On the non-AOT paths, setting ``config.enable_functionalization``
    re-traces the graph through ``functionalize`` + ``make_fx``.

    Args:
        f: A callable, an ``nn.Module``, or a bound method of an ``nn.Module``.
        args: Example positional inputs; must be a tuple.
        config: Capture options. Defaults to ``CaptureConfig()`` when None.
        dynamic_shapes: Forwarded to ``torch.export.export`` or
            ``dynamo_trace``. Ignored on the non-dynamic dynamo path.

    Returns:
        An ``ExirExportedProgram`` wrapping the captured program.

    Raises:
        ExportError: If ``args`` is not a tuple, if the config combination is
            unsupported, or if the placeholder count does not match the
            flattened inputs.
        InternalError: If ``config.pt2_mode`` is false.
    """
    warnings.warn(
        "This function is now deprecated, please use `torch.export and exir.to_edge` instead. ",
        DeprecationWarning,
        # stacklevel=2 attributes the warning to the caller, not this module.
        stacklevel=2,
    )
    if not isinstance(args, tuple):
        raise ExportError(
            ExportErrorType.INVALID_INPUT_TYPE,
            f"Expect `args` to be a tuple, got type: {type(args)}.",
        )

    config = config or CaptureConfig()
    out_spec = None
    # TODO (zhxchen17) Always functionalize in a second pass no matter which path is taken.
    flat_args = tuple(pytree.tree_flatten(args)[0])
    if not config.enable_aot:
        if config._unlift:
            raise ExportError(
                ExportErrorType.NOT_SUPPORTED,
                "_unlift config doesn't do anything without enable_aot enabled. Please do not set it",
            )
    if config.pt2_mode:
        if config.enable_aot:
            if config.enable_dynamic_shape:
                raise ExportError(
                    ExportErrorType.NOT_SUPPORTED,
                    "Under enable_aot, enable_dynamic_shapes flag doesn't do anything. Please do not set it",
                )
            if not config.enable_functionalization:
                raise ExportError(
                    ExportErrorType.NOT_SUPPORTED,
                    "Functionalization is required for enable_aot.",
                )

            # If trying to capture a method and the bound class instance is a
            # Module, then export the module while patching in that method.
            if isinstance(f, MethodType) and isinstance(f.__self__, torch.nn.Module):
                with patch_forward(f.__self__, f):
                    ep = export(
                        cast(torch.nn.Module, f.__self__),
                        args,
                        dynamic_shapes=dynamic_shapes,
                    )
            else:
                mod = f if isinstance(f, torch.nn.Module) else WrapperModule(f)
                ep = export(mod, args, dynamic_shapes=dynamic_shapes)

            ep = ep.run_decompositions(_default_decomposition_table())
            ep = _transform(ep, ReplaceViewOpsWithViewCopyOpsPass())
            if not config._unlift:
                return ExirExportedProgram(ep, False)
            graph_module = cast(torch.fx.GraphModule, ep.module())

        elif config.enable_dynamic_shape:
            graph_module, _ = dynamo_trace(
                f,
                args,
                aten_graph=True,
                tracing_mode="symbolic",
                dynamo_config=config._dynamo_config,
                dynamic_shapes=dynamic_shapes,
                _use_old_decomp_table=config._use_old_decomp_table,
            )

        else:
            graph_module, _ = dynamo_trace(
                f,
                args,
                aten_graph=True,
                tracing_mode="fake",
                dynamo_config=config._dynamo_config,
                dynamic_shapes=None,
                _use_old_decomp_table=config._use_old_decomp_table,
            )

        if out_spec is None:
            if isinstance(graph_module.graph._codegen, torch.fx.graph._PyTreeCodeGen):
                out_spec = graph_module.graph._codegen.pytree_info.out_spec
            elif hasattr(graph_module, "_out_spec"):
                out_spec = graph_module._out_spec
            else:
                # Last resort: run ``f`` eagerly on the example inputs to
                # learn the output structure.
                out_spec = pytree.tree_flatten(f(*args))[1]

        # NOTE (tmanlaibaatar)
        # torchdynamo.export adds extra kwarg into the graph module
        # which is then lost while we are calling make_fx. This is because
        # make_fx doesn't handle kwargs. Originally we used to use torchdynamo
        # input spec, but due to some limitations in pytree implementation, it doesn't
        # recognize the make_fx graph with torchdynamo input spec. We workaround it
        # by getting the input spec directly from user argument.
        in_spec = pytree.tree_flatten((args, {}))[1]

        if config.enable_functionalization and not config.enable_aot:
            args = copy.deepcopy(args)

            def graph_with_interpreter(*args):
                # Re-run the traced graph node by node so make_fx can retrace
                # it, keeping the original node metadata.
                with torch.fx.traceback.preserve_node_meta():
                    return torch.fx.Interpreter(graph_module).run(*args)

            functionalized_callable = functionalize(
                graph_with_interpreter,
                remove="mutations_and_views",
            )
            assert isinstance(functionalized_callable, Callable)

            if config.enable_dynamic_shape:
                fake_tensor_mode = FakeTensorMode(
                    allow_fallback_kernels=False,
                    allow_non_fake_inputs=True,
                    shape_env=ShapeEnv(),
                )

                inps: List[torch.Tensor] = []
                for node in graph_module.graph.nodes:
                    if node.op == "placeholder" and "val" in node.meta:
                        example_fake_tensor = node.meta["val"]
                        assert isinstance(example_fake_tensor, FakeTensor)
                        inps.append(example_fake_tensor)

                # Reuse the fake mode that produced the placeholders' fake
                # tensors when one can be detected.
                if detected_fake_mode := _guards.detect_fake_mode(inps):
                    fake_tensor_mode = detected_fake_mode

                count = 0

                # Replace each tensor in ``args``, in pytree order, with the
                # next placeholder fake tensor.
                def convert_to_fake(x):
                    nonlocal count
                    val = inps[count]
                    count += 1
                    return val

                fake_args = pytree.tree_map_only(torch.Tensor, convert_to_fake, args)

                with enable_python_dispatcher(), fake_tensor_mode:
                    graph_module = make_fx(
                        functionalized_callable,
                        tracing_mode="real",
                        _allow_non_fake_inputs=True,
                    )(*fake_args)
            else:
                # This branch only runs under ``config.pt2_mode``, so fake
                # tracing is always used here.
                graph_module = make_fx(
                    functionalized_callable,
                    tracing_mode="fake",
                    _allow_non_fake_inputs=True,
                )(*args)

        flatten_output(graph_module)

    else:
        raise InternalError("pt2=False path is officially deprecated")

    _instantiate_missing_placeholder_val_with_real_inputs(graph_module, flat_args)
    graph_module._apply(torch.Tensor.contiguous)

    user_inputs = [
        InputSpec(
            kind=InputKind.USER_INPUT, arg=TensorArgument(name=node.name), target=None
        )
        for node in graph_module.graph.nodes
        if node.op == "placeholder"
    ]
    output_node = list(graph_module.graph.nodes)[-1]
    assert output_node.op == "output"
    user_outputs = [
        OutputSpec(
            kind=OutputKind.USER_OUTPUT, arg=TensorArgument(name=arg.name), target=None
        )
        for arg in output_node.args[0]
    ]

    graph_module.graph.eliminate_dead_code()
    ep = ExportedProgram(
        root=graph_module,
        graph=graph_module.graph,
        graph_signature=ExportGraphSignature(user_inputs, user_outputs),
        state_dict={},
        range_constraints={},
        module_call_graph=[
            ModuleCallEntry(
                fqn="",
                signature=ModuleCallSignature(
                    inputs=[],
                    outputs=[],
                    in_spec=in_spec,
                    # pyre-fixme[6]: For 4th argument expected `TreeSpec` but got
                    #  `Union[None, TreeSpec, Tensor, Module]`.
                    out_spec=out_spec,
                ),
            )
        ],
        example_inputs=None,
        verifiers=[EXIRATenDialectVerifierBase],
    )
    return ExirExportedProgram(ep, False)


def _instantiate_missing_placeholder_val_with_real_inputs(gm, args):
    """Fill in missing ``meta["val"]`` on placeholders from real inputs.

    This bootstraps the metadata when (1) a placeholder is a scalar, or
    (2) ``meta["val"]`` was not properly set by ``dispatch_trace``.
    Placeholders are matched to ``args`` by position. Only placeholders whose
    ``"val"`` is absent or ``None`` are overwritten.

    Args:
        gm: The graph module whose placeholder metadata is updated in place.
        args: Flattened user inputs, one per placeholder.

    Raises:
        ExportError: If the placeholder count differs from ``len(args)``.
    """
    phs = [node for node in gm.graph.nodes if node.op == "placeholder"]
    if len(phs) != len(args):
        raise ExportError(
            ExportErrorType.NOT_SUPPORTED,
            "Expect number of placeholders to be the same as user inputs.",
        )
    for node, arg in zip(phs, args):
        if "val" not in node.meta or node.meta["val"] is None:
            node.meta["val"] = arg