1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7import copy 8import warnings 9from collections import namedtuple 10from contextlib import contextmanager 11from types import MethodType 12from typing import Any, Callable, cast, List, Optional, Tuple 13 14import torch 15from executorch.exir.capture._config import CaptureConfig 16from executorch.exir.error import ExportError, ExportErrorType, InternalError 17from executorch.exir.program import ExirExportedProgram 18from executorch.exir.program._program import _transform, HackedUpExportedProgramDONOTUSE 19from executorch.exir.tracer import ( 20 _default_decomposition_table, 21 dispatch_trace, 22 dynamo_trace, 23 flatten_output, 24 Value, 25) 26from executorch.exir.verification.verifier import EXIRATenDialectVerifierBase 27from torch import _guards 28from torch._dispatch.python import enable_python_dispatcher 29from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass 30from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode 31from torch.export import export 32from torch.export.exported_program import ( 33 ExportedProgram, 34 ExportGraphSignature, 35 InputKind, 36 InputSpec, 37 ModuleCallEntry, 38 ModuleCallSignature, 39 OutputKind, 40 OutputSpec, 41 TensorArgument, 42) 43from torch.func import functionalize 44from torch.fx._compatibility import compatibility 45from torch.fx.experimental.proxy_tensor import make_fx 46from torch.fx.experimental.symbolic_shapes import ShapeEnv 47from torch.utils import _pytree as pytree 48 49 50Val = Any 51 52 53CompileSpec = namedtuple( 54 "CompileSpec", ["method_name", "callable", "args", "dynamic_shapes"] 55) 56 57 58CallSpec = namedtuple("CallSpec", ["in_spec", "out_spec"]) 59 60 61@compatibility(is_backward_compatible=False) 62def _capture_legacy_do_not_use(f, args) -> ExirExportedProgram: 63 """ 64 This is a legacy API that should be avoided. Prefer to use capture() instead. 65 """ 66 warnings.warn( 67 "This function is now deprecated, please use `torch.export and exir.to_edge` instead. " 68 "See https://github.com/pytorch/functorch for more details.", 69 DeprecationWarning, 70 stacklevel=1, 71 ) 72 73 graph_module = dispatch_trace(f, args) 74 flat_args = tuple(pytree.tree_flatten(args)[0]) 75 in_spec, out_spec = graph_module.in_spec, graph_module.out_spec 76 77 _instantiate_missing_placeholder_val_with_real_inputs(graph_module, flat_args) 78 graph_module._apply(torch.Tensor.contiguous) 79 80 user_inputs = [ 81 node.name for node in graph_module.graph.nodes if node.op == "placeholder" 82 ] 83 output_node = list(graph_module.graph.nodes)[-1] 84 assert output_node.op == "output" 85 user_outputs = [arg.name for arg in output_node.args[0]] 86 87 for n in graph_module.graph.nodes: 88 if n.op == "call_function" and "val" not in n.meta: 89 try: 90 args, kwargs = pytree.tree_map_only( 91 torch.fx.Node, lambda x: x.meta["val"], (n.args, n.kwargs) 92 ) 93 n.meta["val"] = n.target(*args, **kwargs) 94 except Exception: 95 n.meta["val"] = None 96 97 ep = HackedUpExportedProgramDONOTUSE( 98 root=graph_module, 99 graph=graph_module.graph, 100 graph_signature=ExportGraphSignature( 101 input_specs=[ 102 InputSpec( 103 kind=InputKind.USER_INPUT, arg=TensorArgument(name=i), target=None 104 ) 105 for i in user_inputs 106 ], 107 output_specs=[ 108 OutputSpec( 109 kind=OutputKind.USER_OUTPUT, arg=TensorArgument(name=o), target=None 110 ) 111 for o in user_outputs 112 ], 113 ), 114 call_spec=CallSpec(in_spec, out_spec), 115 state_dict={}, 116 range_constraints={}, 117 module_call_graph=[ 118 ModuleCallEntry( 119 fqn="", 120 signature=ModuleCallSignature( 121 inputs=[], 122 outputs=[], 123 # pyre-fixme[6]: For 3rd argument expected `TreeSpec` but got 124 # `Union[Tensor, Module]`. 125 in_spec=in_spec, 126 # pyre-fixme[6]: For 4th argument expected `TreeSpec` but got 127 # `Union[Tensor, Module]`. 128 out_spec=out_spec, 129 ), 130 ) 131 ], 132 example_inputs=None, 133 verifier=EXIRATenDialectVerifierBase, 134 ) 135 return ExirExportedProgram(ep, False) 136 137 138@contextmanager 139def patch_forward(obj: torch.nn.Module, new_method): 140 """Helper method to make it easier to cleanly torch.export() a method on a 141 module that is not `forward`. 142 143 TODO(suo): upstream this to torch.export.wrapper. 144 """ 145 # Save the original method 146 original_method = obj.forward 147 148 # Patch the method 149 obj.forward = new_method.__get__(obj, obj.__class__) 150 151 try: 152 yield 153 finally: 154 # Restore the original method 155 obj.forward = original_method 156 157 158class WrapperModule(torch.nn.Module): 159 def __init__(self, f): 160 super().__init__() 161 self.forward = f 162 163 164@compatibility(is_backward_compatible=False) 165def capture( # noqa: C901 166 f: Callable[..., Any], 167 args: Tuple[Value, ...], 168 config: Optional[CaptureConfig] = None, 169 dynamic_shapes: Optional[List[Any]] = None, 170) -> ExirExportedProgram: 171 warnings.warn( 172 "This function is now deprecated, please use `torch.export and exir.to_edge` instead. ", 173 DeprecationWarning, 174 stacklevel=1, 175 ) 176 if not isinstance(args, tuple): 177 raise ExportError( 178 ExportErrorType.INVALID_INPUT_TYPE, 179 f"Expect `args` to be a tuple, got type: {type(args)}.", 180 ) 181 182 config = config or CaptureConfig() 183 out_spec = None 184 # TODO (zhxchen17) Always functionalize in a second pass no matter which path is taken. 185 flat_args = tuple(pytree.tree_flatten(args)[0]) 186 if not config.enable_aot: 187 if config._unlift: 188 raise ExportError( 189 ExportErrorType.NOT_SUPPORTED, 190 "_unlift config doesn't do anything without enable_aot enabled. Please do not set it", 191 ) 192 if config.pt2_mode: 193 if config.enable_aot: 194 if config.enable_dynamic_shape: 195 raise ExportError( 196 ExportErrorType.NOT_SUPPORTED, 197 "Under enable_aot, enable_dynamic_shapes flag doesn't do anything. Please do not set it", 198 ) 199 if not config.enable_functionalization: 200 raise ExportError( 201 ExportErrorType.NOT_SUPPORTED, 202 "Functionalization is required for enable_aot.", 203 ) 204 205 # If trying to capture a method and the bound class instance is a 206 # Module, then export the module while patching in that method. 207 if isinstance(f, MethodType) and isinstance(f.__self__, torch.nn.Module): 208 with patch_forward(f.__self__, f): 209 ep = export( 210 cast(torch.nn.Module, f.__self__), 211 args, 212 dynamic_shapes=dynamic_shapes, 213 ) 214 else: 215 mod = f if isinstance(f, torch.nn.Module) else WrapperModule(f) 216 ep = export(mod, args, dynamic_shapes=dynamic_shapes) 217 218 ep = ep.run_decompositions(_default_decomposition_table()) 219 ep = _transform(ep, ReplaceViewOpsWithViewCopyOpsPass()) 220 if not config._unlift: 221 return ExirExportedProgram(ep, False) 222 graph_module = cast(torch.fx.GraphModule, ep.module()) 223 224 elif config.enable_dynamic_shape: 225 graph_module, _ = dynamo_trace( 226 f, 227 args, 228 aten_graph=True, 229 tracing_mode="symbolic", 230 dynamo_config=config._dynamo_config, 231 dynamic_shapes=dynamic_shapes, 232 _use_old_decomp_table=config._use_old_decomp_table, 233 ) 234 235 else: 236 graph_module, _ = dynamo_trace( 237 f, 238 args, 239 aten_graph=True, 240 tracing_mode="fake", 241 dynamo_config=config._dynamo_config, 242 dynamic_shapes=None, 243 _use_old_decomp_table=config._use_old_decomp_table, 244 ) 245 246 if out_spec is None: 247 if isinstance(graph_module.graph._codegen, torch.fx.graph._PyTreeCodeGen): 248 out_spec = graph_module.graph._codegen.pytree_info.out_spec 249 elif hasattr(graph_module, "_out_spec"): 250 out_spec = graph_module._out_spec 251 else: 252 out_spec = pytree.tree_flatten(f(*args))[1] 253 254 # NOTE (tmanlaibaatar) 255 # torchdynamo.export adds extra kwarg into the graph module 256 # which is then lost while we are calling make_fx. This is because 257 # make_fx doesn't handle kwargs. Originally we used to use torchdynamo 258 # input spec, but due to some limitations in pytree implementation, it doesn't 259 # recognize the make_fx graph with torchdynamo input spec. We workaround it 260 # by getting the input spec directly from user argument. 261 in_spec = pytree.tree_flatten((args, {}))[1] 262 263 if config.enable_functionalization and not config.enable_aot: 264 args = copy.deepcopy(args) 265 266 def graph_with_interpreter(*args): 267 with torch.fx.traceback.preserve_node_meta(): 268 return torch.fx.Interpreter(graph_module).run(*args) 269 270 functionalized_callable = functionalize( 271 graph_with_interpreter, 272 remove="mutations_and_views", 273 ) 274 assert isinstance(functionalized_callable, Callable) 275 276 if config.enable_dynamic_shape: 277 fake_tensor_mode = FakeTensorMode( 278 allow_fallback_kernels=False, 279 allow_non_fake_inputs=True, 280 shape_env=ShapeEnv(), 281 ) 282 283 inps: List[torch.Tensor] = [] 284 for node in graph_module.graph.nodes: 285 if node.op == "placeholder" and "val" in node.meta: 286 example_fake_tensor = node.meta["val"] 287 assert isinstance(example_fake_tensor, FakeTensor) 288 inps.append(example_fake_tensor) 289 290 if detected_fake_mode := _guards.detect_fake_mode(inps): 291 fake_tensor_mode = detected_fake_mode 292 293 count = 0 294 295 def convert_to_fake(x): 296 nonlocal count 297 val = inps[count] 298 count += 1 299 return val 300 301 fake_args = pytree.tree_map_only(torch.Tensor, convert_to_fake, args) 302 303 with enable_python_dispatcher(), fake_tensor_mode: 304 graph_module = make_fx( 305 functionalized_callable, 306 tracing_mode="real", 307 _allow_non_fake_inputs=True, 308 )(*fake_args) 309 else: 310 # To avoid breaking folks, use the deprecated "real" tracing 311 # mode if we're not using pt2. 312 tracing_mode = "fake" if config.pt2_mode else "real" 313 graph_module = make_fx( 314 functionalized_callable, 315 tracing_mode=tracing_mode, 316 _allow_non_fake_inputs=True, 317 )(*args) 318 319 flatten_output(graph_module) 320 321 else: 322 raise InternalError("pt2=False path is officially deprecated") 323 324 _instantiate_missing_placeholder_val_with_real_inputs(graph_module, flat_args) 325 graph_module._apply(torch.Tensor.contiguous) 326 327 user_inputs = [ 328 InputSpec( 329 kind=InputKind.USER_INPUT, arg=TensorArgument(name=node.name), target=None 330 ) 331 for node in graph_module.graph.nodes 332 if node.op == "placeholder" 333 ] 334 output_node = list(graph_module.graph.nodes)[-1] 335 assert output_node.op == "output" 336 user_outputs = [ 337 OutputSpec( 338 kind=OutputKind.USER_OUTPUT, arg=TensorArgument(name=arg.name), target=None 339 ) 340 for arg in output_node.args[0] 341 ] 342 343 graph_module.graph.eliminate_dead_code() 344 ep = ExportedProgram( 345 root=graph_module, 346 graph=graph_module.graph, 347 graph_signature=ExportGraphSignature(user_inputs, user_outputs), 348 state_dict={}, 349 range_constraints={}, 350 module_call_graph=[ 351 ModuleCallEntry( 352 fqn="", 353 signature=ModuleCallSignature( 354 inputs=[], 355 outputs=[], 356 in_spec=in_spec, 357 # pyre-fixme[6]: For 4th argument expected `TreeSpec` but got 358 # `Union[None, TreeSpec, Tensor, Module]`. 359 out_spec=out_spec, 360 ), 361 ) 362 ], 363 example_inputs=None, 364 verifiers=[EXIRATenDialectVerifierBase], 365 ) 366 return ExirExportedProgram(ep, False) 367 368 369# This is to bootstrap the missing meta["val"] when 1. ph consists of scalar 370# 2. meta["val"] is not properly set in dispatch_trace. 371def _instantiate_missing_placeholder_val_with_real_inputs(gm, args): 372 phs = [node for node in gm.graph.nodes if node.op == "placeholder"] 373 if len(phs) != len(args): 374 raise ExportError( 375 ExportErrorType.NOT_SUPPORTED, 376 "Expect number of placeholders to be the same as user inputs.", 377 ) 378 for node, arg in zip(phs, args): 379 if "val" not in node.meta or node.meta["val"] is None: 380 node.meta["val"] = arg 381