1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates. 2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved. 3*523fa7a6SAndroid Build Coastguard Worker# 4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the 5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree. 6*523fa7a6SAndroid Build Coastguard Worker 7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict 8*523fa7a6SAndroid Build Coastguard Worker 9*523fa7a6SAndroid Build Coastguard Workerimport copy 10*523fa7a6SAndroid Build Coastguard Workerimport json 11*523fa7a6SAndroid Build Coastguard Workerimport traceback 12*523fa7a6SAndroid Build Coastguard Workerfrom contextlib import contextmanager 13*523fa7a6SAndroid Build Coastguard Workerfrom dataclasses import asdict, dataclass 14*523fa7a6SAndroid Build Coastguard Workerfrom typing import ( 15*523fa7a6SAndroid Build Coastguard Worker Any, 16*523fa7a6SAndroid Build Coastguard Worker Callable, 17*523fa7a6SAndroid Build Coastguard Worker Dict, 18*523fa7a6SAndroid Build Coastguard Worker Generator, 19*523fa7a6SAndroid Build Coastguard Worker Iterable, 20*523fa7a6SAndroid Build Coastguard Worker List, 21*523fa7a6SAndroid Build Coastguard Worker Optional, 22*523fa7a6SAndroid Build Coastguard Worker Set, 23*523fa7a6SAndroid Build Coastguard Worker Tuple, 24*523fa7a6SAndroid Build Coastguard Worker Union, 25*523fa7a6SAndroid Build Coastguard Worker) 26*523fa7a6SAndroid Build Coastguard Worker 27*523fa7a6SAndroid Build Coastguard Workerimport executorch.extension.pytree as ex_pytree 28*523fa7a6SAndroid Build Coastguard Workerimport torch 29*523fa7a6SAndroid Build Coastguard Workerimport torch._dynamo as torchdynamo 30*523fa7a6SAndroid Build Coastguard Workerimport torch.fx as fx 31*523fa7a6SAndroid Build Coastguard Worker 32*523fa7a6SAndroid Build Coastguard Workerimport torch.fx._pytree as fx_pytree 33*523fa7a6SAndroid 
Build Coastguard Workerimport torch.utils._pytree as pytree 34*523fa7a6SAndroid Build Coastguard Worker 35*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.common import ( 36*523fa7a6SAndroid Build Coastguard Worker extract_out_arguments, 37*523fa7a6SAndroid Build Coastguard Worker format_schema_name, 38*523fa7a6SAndroid Build Coastguard Worker no_dispatch, 39*523fa7a6SAndroid Build Coastguard Worker setting_python_recursive_limit, 40*523fa7a6SAndroid Build Coastguard Worker) 41*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.error import ExportError, ExportErrorType, InternalError 42*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.graph_module import LeafValue 43*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.operator.convert import is_out_variant 44*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.types import ValueSpec 45*523fa7a6SAndroid Build Coastguard Worker 46*523fa7a6SAndroid Build Coastguard Workerfrom torch._C import _EnableTorchFunction, DisableTorchFunctionSubclass # @manual 47*523fa7a6SAndroid Build Coastguard Workerfrom torch._decomp import get_decompositions 48*523fa7a6SAndroid Build Coastguard Workerfrom torch._dynamo.guards import Guard 49*523fa7a6SAndroid Build Coastguard Workerfrom torch._functorch.eager_transforms import _maybe_unwrap_functional_tensor 50*523fa7a6SAndroid Build Coastguard Workerfrom torch.export import default_decompositions 51*523fa7a6SAndroid Build Coastguard Workerfrom torch.func import functionalize 52*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.operator_schemas import normalize_function 53*523fa7a6SAndroid Build Coastguard Workerfrom torch.utils._pytree import TreeSpec 54*523fa7a6SAndroid Build Coastguard Worker 55*523fa7a6SAndroid Build Coastguard Workerfrom typing_extensions import TypeAlias 56*523fa7a6SAndroid Build Coastguard Worker 57*523fa7a6SAndroid Build Coastguard Worker 58*523fa7a6SAndroid Build Coastguard WorkerValue: TypeAlias = Union[ 
59*523fa7a6SAndroid Build Coastguard Worker LeafValue, 60*523fa7a6SAndroid Build Coastguard Worker Tuple["Value", ...], 61*523fa7a6SAndroid Build Coastguard Worker List["Value"], 62*523fa7a6SAndroid Build Coastguard Worker Dict[str, "Value"], 63*523fa7a6SAndroid Build Coastguard Worker] 64*523fa7a6SAndroid Build Coastguard Worker 65*523fa7a6SAndroid Build Coastguard Workertorchdynamo_enabled = False 66*523fa7a6SAndroid Build Coastguard Worker 67*523fa7a6SAndroid Build Coastguard Worker 68*523fa7a6SAndroid Build Coastguard Workerdef get_stacktrace() -> List[Dict[str, str]]: 69*523fa7a6SAndroid Build Coastguard Worker """ 70*523fa7a6SAndroid Build Coastguard Worker Get the current stacktrace (between trace() and __torch_dispatch__()) 71*523fa7a6SAndroid Build Coastguard Worker Include the filename, function name, line number, and source code from the 72*523fa7a6SAndroid Build Coastguard Worker start of the function to the given instruction. 73*523fa7a6SAndroid Build Coastguard Worker 74*523fa7a6SAndroid Build Coastguard Worker Return: 75*523fa7a6SAndroid Build Coastguard Worker A list of stacktraces for each instruction along with the source code 76*523fa7a6SAndroid Build Coastguard Worker context surrounding each instruction 77*523fa7a6SAndroid Build Coastguard Worker """ 78*523fa7a6SAndroid Build Coastguard Worker 79*523fa7a6SAndroid Build Coastguard Worker stacktrace = traceback.extract_stack() 80*523fa7a6SAndroid Build Coastguard Worker 81*523fa7a6SAndroid Build Coastguard Worker # The stacktrace typically looks like this: 82*523fa7a6SAndroid Build Coastguard Worker # 83*523fa7a6SAndroid Build Coastguard Worker # 1. I stack frames from the top level runner (e.g., the 84*523fa7a6SAndroid Build Coastguard Worker # test suite runner) 85*523fa7a6SAndroid Build Coastguard Worker # 2. J frames in executorch/exir/tracer.py setting up tracing 86*523fa7a6SAndroid Build Coastguard Worker # (call this INIT_EXIR) 87*523fa7a6SAndroid Build Coastguard Worker # 3. 
K frames in user model code (this is what we want to save!) 88*523fa7a6SAndroid Build Coastguard Worker # 4. 1 frame in executorch/exir/tracer.py __torch_function__ 89*523fa7a6SAndroid Build Coastguard Worker # returning to tracer (call this TRACE_EXIR) 90*523fa7a6SAndroid Build Coastguard Worker # 5. H frames in executorch/exir/tracer.py AND torch/_tensor.py 91*523fa7a6SAndroid Build Coastguard Worker # doing all of the internal tracer handling 92*523fa7a6SAndroid Build Coastguard Worker # 93*523fa7a6SAndroid Build Coastguard Worker # The PyE tests assert that executorch/exir/tracer.py never shows 94*523fa7a6SAndroid Build Coastguard Worker # up in the user provided stack traces, so we must oblige them. 95*523fa7a6SAndroid Build Coastguard Worker # 96*523fa7a6SAndroid Build Coastguard Worker # Assumptions: 97*523fa7a6SAndroid Build Coastguard Worker # - Reentrant tracing is not a thing. Thus, the first time 98*523fa7a6SAndroid Build Coastguard Worker # executorch/exir/tracer.py shows up in the trace, we know 99*523fa7a6SAndroid Build Coastguard Worker # THAT is the point at which we start tracing. (An alternative 100*523fa7a6SAndroid Build Coastguard Worker # is that the tracer entry point could record the stack trace 101*523fa7a6SAndroid Build Coastguard Worker # at this time, but I didn't do this.) 102*523fa7a6SAndroid Build Coastguard Worker # 103*523fa7a6SAndroid Build Coastguard Worker # Our plan is to do a miniature stack machine traversing these 104*523fa7a6SAndroid Build Coastguard Worker # stack machines. 105*523fa7a6SAndroid Build Coastguard Worker 106*523fa7a6SAndroid Build Coastguard Worker # Remove parts before the trace function and parts after entering 107*523fa7a6SAndroid Build Coastguard Worker # __torch_dispatch__. Defaults to returning the entire stack trace. 
108*523fa7a6SAndroid Build Coastguard Worker init_exir_end = 0 109*523fa7a6SAndroid Build Coastguard Worker trace_exir_start = None 110*523fa7a6SAndroid Build Coastguard Worker # A miniature state machine, referring to the frame segments described 111*523fa7a6SAndroid Build Coastguard Worker # above. The locations are closed-open interval. 112*523fa7a6SAndroid Build Coastguard Worker FIND_INIT_EXIR_START, FIND_INIT_EXIR_END, FIND_TRACE_EXIR_START = range(3) 113*523fa7a6SAndroid Build Coastguard Worker state = FIND_INIT_EXIR_START 114*523fa7a6SAndroid Build Coastguard Worker for i, frame in enumerate(stacktrace): 115*523fa7a6SAndroid Build Coastguard Worker if state == FIND_INIT_EXIR_START: 116*523fa7a6SAndroid Build Coastguard Worker if "executorch/exir/tracer.py" in frame.filename: 117*523fa7a6SAndroid Build Coastguard Worker state = FIND_INIT_EXIR_END 118*523fa7a6SAndroid Build Coastguard Worker elif state == FIND_INIT_EXIR_END: 119*523fa7a6SAndroid Build Coastguard Worker if "executorch/exir/tracer.py" not in frame.filename: 120*523fa7a6SAndroid Build Coastguard Worker init_exir_end = i 121*523fa7a6SAndroid Build Coastguard Worker state = FIND_TRACE_EXIR_START 122*523fa7a6SAndroid Build Coastguard Worker elif state == FIND_TRACE_EXIR_START: 123*523fa7a6SAndroid Build Coastguard Worker if "executorch/exir/tracer.py" in frame.filename: 124*523fa7a6SAndroid Build Coastguard Worker trace_exir_start = i 125*523fa7a6SAndroid Build Coastguard Worker break 126*523fa7a6SAndroid Build Coastguard Worker 127*523fa7a6SAndroid Build Coastguard Worker stacktrace = stacktrace[init_exir_end:trace_exir_start] 128*523fa7a6SAndroid Build Coastguard Worker 129*523fa7a6SAndroid Build Coastguard Worker # Get the source code from the errored line to it 130*523fa7a6SAndroid Build Coastguard Worker contexts: List[str] = [] 131*523fa7a6SAndroid Build Coastguard Worker for s in stacktrace: 132*523fa7a6SAndroid Build Coastguard Worker try: 133*523fa7a6SAndroid Build Coastguard Worker with 
open(s.filename) as file: 134*523fa7a6SAndroid Build Coastguard Worker # pyre-fixme[6]: For 1st param expected `Union[SupportsTrunc, bytes, 135*523fa7a6SAndroid Build Coastguard Worker # str, SupportsInt, SupportsIndex]` but got `Optional[int]`. 136*523fa7a6SAndroid Build Coastguard Worker lineno = int(s.lineno) 137*523fa7a6SAndroid Build Coastguard Worker # Get the source code 5 lines above/below the current instruction 138*523fa7a6SAndroid Build Coastguard Worker file_contents = [ 139*523fa7a6SAndroid Build Coastguard Worker str(index + 1) + line for index, line in enumerate(file.readlines()) 140*523fa7a6SAndroid Build Coastguard Worker ] 141*523fa7a6SAndroid Build Coastguard Worker file_contents_above = "".join( 142*523fa7a6SAndroid Build Coastguard Worker file_contents[max(lineno - 5, 0) : lineno] 143*523fa7a6SAndroid Build Coastguard Worker ) 144*523fa7a6SAndroid Build Coastguard Worker file_contents_below = "".join( 145*523fa7a6SAndroid Build Coastguard Worker file_contents[lineno : min(lineno + 5, len(file_contents))] 146*523fa7a6SAndroid Build Coastguard Worker ) 147*523fa7a6SAndroid Build Coastguard Worker context = ( 148*523fa7a6SAndroid Build Coastguard Worker file_contents_above 149*523fa7a6SAndroid Build Coastguard Worker + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n" 150*523fa7a6SAndroid Build Coastguard Worker + file_contents_below 151*523fa7a6SAndroid Build Coastguard Worker ) 152*523fa7a6SAndroid Build Coastguard Worker contexts.append(context) 153*523fa7a6SAndroid Build Coastguard Worker except FileNotFoundError: 154*523fa7a6SAndroid Build Coastguard Worker contexts.append("<unknown file: unknown line>") 155*523fa7a6SAndroid Build Coastguard Worker 156*523fa7a6SAndroid Build Coastguard Worker # torch.fx stack preservation logic expects strings to 157*523fa7a6SAndroid Build Coastguard Worker # be passed around. Working with dictionary is lot easier 158*523fa7a6SAndroid Build Coastguard Worker # to convert to string and vice versa. 
159*523fa7a6SAndroid Build Coastguard Worker frames: List[Dict[str, str]] = [] 160*523fa7a6SAndroid Build Coastguard Worker for i, frame in enumerate(stacktrace): 161*523fa7a6SAndroid Build Coastguard Worker frames.append( 162*523fa7a6SAndroid Build Coastguard Worker { 163*523fa7a6SAndroid Build Coastguard Worker "filename": str(frame.filename), 164*523fa7a6SAndroid Build Coastguard Worker "lineno": str(frame.lineno), 165*523fa7a6SAndroid Build Coastguard Worker "name": str(frame.name), 166*523fa7a6SAndroid Build Coastguard Worker "line": str(frame.line), 167*523fa7a6SAndroid Build Coastguard Worker "context": contexts[i], 168*523fa7a6SAndroid Build Coastguard Worker } 169*523fa7a6SAndroid Build Coastguard Worker ) 170*523fa7a6SAndroid Build Coastguard Worker 171*523fa7a6SAndroid Build Coastguard Worker return frames 172*523fa7a6SAndroid Build Coastguard Worker 173*523fa7a6SAndroid Build Coastguard Worker 174*523fa7a6SAndroid Build Coastguard Workerdef unwrap_functional(t: torch.Tensor) -> torch.Tensor: 175*523fa7a6SAndroid Build Coastguard Worker assert isinstance(t, torch.Tensor) 176*523fa7a6SAndroid Build Coastguard Worker return _maybe_unwrap_functional_tensor(t, reapply_views=False) 177*523fa7a6SAndroid Build Coastguard Worker 178*523fa7a6SAndroid Build Coastguard Worker 179*523fa7a6SAndroid Build Coastguard Workerdef unwrap_proxy(t: LeafValue) -> Union[LeafValue, torch.fx.Proxy]: 180*523fa7a6SAndroid Build Coastguard Worker if not isinstance(t, torch.Tensor): 181*523fa7a6SAndroid Build Coastguard Worker return t 182*523fa7a6SAndroid Build Coastguard Worker t = unwrap_functional(t) 183*523fa7a6SAndroid Build Coastguard Worker return t.proxy if isinstance(t, PythonTensor) else t 184*523fa7a6SAndroid Build Coastguard Worker 185*523fa7a6SAndroid Build Coastguard Worker 186*523fa7a6SAndroid Build Coastguard Workerdef single_return( 187*523fa7a6SAndroid Build Coastguard Worker output: LeafValue, 188*523fa7a6SAndroid Build Coastguard Worker proxy: torch.fx.Proxy, 
189*523fa7a6SAndroid Build Coastguard Worker wrapper: Callable[..., LeafValue], 190*523fa7a6SAndroid Build Coastguard Worker) -> LeafValue: 191*523fa7a6SAndroid Build Coastguard Worker if isinstance(output, torch.Tensor): 192*523fa7a6SAndroid Build Coastguard Worker return wrapper(output, proxy) 193*523fa7a6SAndroid Build Coastguard Worker 194*523fa7a6SAndroid Build Coastguard Worker return output 195*523fa7a6SAndroid Build Coastguard Worker 196*523fa7a6SAndroid Build Coastguard Worker 197*523fa7a6SAndroid Build Coastguard Workerdef tree_return( 198*523fa7a6SAndroid Build Coastguard Worker outputs: Value, 199*523fa7a6SAndroid Build Coastguard Worker proxy: torch.fx.Proxy, 200*523fa7a6SAndroid Build Coastguard Worker wrapper: Callable[..., LeafValue], 201*523fa7a6SAndroid Build Coastguard Worker meta_type: Callable[..., Iterable[ValueSpec]] = tuple, 202*523fa7a6SAndroid Build Coastguard Worker) -> Value: 203*523fa7a6SAndroid Build Coastguard Worker i: int = 0 204*523fa7a6SAndroid Build Coastguard Worker 205*523fa7a6SAndroid Build Coastguard Worker def wrap(o: LeafValue) -> LeafValue: 206*523fa7a6SAndroid Build Coastguard Worker nonlocal i 207*523fa7a6SAndroid Build Coastguard Worker ret = single_return(o, proxy[i], wrapper) 208*523fa7a6SAndroid Build Coastguard Worker i += 1 209*523fa7a6SAndroid Build Coastguard Worker return ret 210*523fa7a6SAndroid Build Coastguard Worker 211*523fa7a6SAndroid Build Coastguard Worker return pytree.tree_map(wrap, outputs) 212*523fa7a6SAndroid Build Coastguard Worker 213*523fa7a6SAndroid Build Coastguard Worker 214*523fa7a6SAndroid Build Coastguard Workerclass DummyProxy: 215*523fa7a6SAndroid Build Coastguard Worker def __init__(self) -> None: 216*523fa7a6SAndroid Build Coastguard Worker class DummyNode: 217*523fa7a6SAndroid Build Coastguard Worker def __init__(self): 218*523fa7a6SAndroid Build Coastguard Worker self.meta = {} 219*523fa7a6SAndroid Build Coastguard Worker 220*523fa7a6SAndroid Build Coastguard Worker self.node = 
DummyNode() 221*523fa7a6SAndroid Build Coastguard Worker 222*523fa7a6SAndroid Build Coastguard Worker def __getitem__(self, key: str) -> "DummyProxy": 223*523fa7a6SAndroid Build Coastguard Worker return DummyProxy() 224*523fa7a6SAndroid Build Coastguard Worker 225*523fa7a6SAndroid Build Coastguard Worker 226*523fa7a6SAndroid Build Coastguard Workerclass PythonTensor(torch.Tensor): 227*523fa7a6SAndroid Build Coastguard Worker """ 228*523fa7a6SAndroid Build Coastguard Worker A wrapper tensor subclass used in the DispatchTracer to keep track of 229*523fa7a6SAndroid Build Coastguard Worker proxies to construct the FX graph. 230*523fa7a6SAndroid Build Coastguard Worker 231*523fa7a6SAndroid Build Coastguard Worker Wrapping something in PythonTensor implicitly detaches gradients. If 232*523fa7a6SAndroid Build Coastguard Worker something required grad, we will collect it as if it were a leaf. A 233*523fa7a6SAndroid Build Coastguard Worker consequence of detaching in this way is you need to maintain a parameter 234*523fa7a6SAndroid Build Coastguard Worker cache when translating tensors into PythonTensor, so you don't create 235*523fa7a6SAndroid Build Coastguard Worker multiple copies of a gradient (they are aliased, but they would count as 236*523fa7a6SAndroid Build Coastguard Worker independent leaves). An alternate strategy would be to avoid implicitly 237*523fa7a6SAndroid Build Coastguard Worker detaching and instead "catch" gradients as they exit the PythonTensor 238*523fa7a6SAndroid Build Coastguard Worker boundary. 
239*523fa7a6SAndroid Build Coastguard Worker """ 240*523fa7a6SAndroid Build Coastguard Worker 241*523fa7a6SAndroid Build Coastguard Worker __slots__ = ["proxy", "is_immutable"] 242*523fa7a6SAndroid Build Coastguard Worker 243*523fa7a6SAndroid Build Coastguard Worker @staticmethod 244*523fa7a6SAndroid Build Coastguard Worker def __new__( 245*523fa7a6SAndroid Build Coastguard Worker cls, elem: torch.Tensor, proxy: torch.fx.Proxy, is_immutable: bool = False 246*523fa7a6SAndroid Build Coastguard Worker ) -> torch.Tensor: 247*523fa7a6SAndroid Build Coastguard Worker # assert not elem.requires_grad or not torch.is_grad_enabled() 248*523fa7a6SAndroid Build Coastguard Worker 249*523fa7a6SAndroid Build Coastguard Worker r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad) 250*523fa7a6SAndroid Build Coastguard Worker assert isinstance(r, PythonTensor) 251*523fa7a6SAndroid Build Coastguard Worker r.is_immutable: bool = is_immutable 252*523fa7a6SAndroid Build Coastguard Worker r.update_proxy(proxy) 253*523fa7a6SAndroid Build Coastguard Worker return r 254*523fa7a6SAndroid Build Coastguard Worker 255*523fa7a6SAndroid Build Coastguard Worker def update_proxy(self, proxy: torch.fx.Proxy) -> None: 256*523fa7a6SAndroid Build Coastguard Worker self.proxy = proxy 257*523fa7a6SAndroid Build Coastguard Worker 258*523fa7a6SAndroid Build Coastguard Worker def __repr__(self, *, tensor_contents: None = None) -> str: 259*523fa7a6SAndroid Build Coastguard Worker with no_dispatch(): 260*523fa7a6SAndroid Build Coastguard Worker return f"PythonTensor({self.as_subclass(torch.Tensor)})" 261*523fa7a6SAndroid Build Coastguard Worker 262*523fa7a6SAndroid Build Coastguard Worker @classmethod 263*523fa7a6SAndroid Build Coastguard Worker def __torch_function__( 264*523fa7a6SAndroid Build Coastguard Worker cls, 265*523fa7a6SAndroid Build Coastguard Worker # pyre-ignore: Missing parameter annotation [2] 266*523fa7a6SAndroid Build Coastguard Worker func, 267*523fa7a6SAndroid Build Coastguard 
Worker # pyre-ignore: Missing parameter annotation [2] 268*523fa7a6SAndroid Build Coastguard Worker types, 269*523fa7a6SAndroid Build Coastguard Worker args: Tuple[Value, ...] = (), 270*523fa7a6SAndroid Build Coastguard Worker kwargs: Optional[Dict[str, Value]] = None, 271*523fa7a6SAndroid Build Coastguard Worker ) -> Value: 272*523fa7a6SAndroid Build Coastguard Worker if kwargs is None: 273*523fa7a6SAndroid Build Coastguard Worker kwargs = {} 274*523fa7a6SAndroid Build Coastguard Worker if torch.is_inference_mode_enabled(): 275*523fa7a6SAndroid Build Coastguard Worker if func is torch.nn.functional.layer_norm: 276*523fa7a6SAndroid Build Coastguard Worker args, kwargs = normalize_function(func, args, kwargs) # pyre-fixme[23] 277*523fa7a6SAndroid Build Coastguard Worker input, normalized_shape = args 278*523fa7a6SAndroid Build Coastguard Worker normalized_shape = list(normalized_shape) 279*523fa7a6SAndroid Build Coastguard Worker return cls.__torch_dispatch__( 280*523fa7a6SAndroid Build Coastguard Worker torch.ops.aten.layer_norm.default, 281*523fa7a6SAndroid Build Coastguard Worker types, 282*523fa7a6SAndroid Build Coastguard Worker (input, normalized_shape), 283*523fa7a6SAndroid Build Coastguard Worker kwargs, 284*523fa7a6SAndroid Build Coastguard Worker ) 285*523fa7a6SAndroid Build Coastguard Worker elif func is torch.nn.functional.linear: 286*523fa7a6SAndroid Build Coastguard Worker return cls.__torch_dispatch__( 287*523fa7a6SAndroid Build Coastguard Worker torch.ops.aten.linear.default, types, args, kwargs 288*523fa7a6SAndroid Build Coastguard Worker ) 289*523fa7a6SAndroid Build Coastguard Worker with DisableTorchFunctionSubclass(): 290*523fa7a6SAndroid Build Coastguard Worker return func(*args, **kwargs) 291*523fa7a6SAndroid Build Coastguard Worker 292*523fa7a6SAndroid Build Coastguard Worker @classmethod 293*523fa7a6SAndroid Build Coastguard Worker def __torch_dispatch__( # noqa: C901 294*523fa7a6SAndroid Build Coastguard Worker cls, 295*523fa7a6SAndroid 
Build Coastguard Worker func_overload: torch._ops.OpOverload, 296*523fa7a6SAndroid Build Coastguard Worker # pyre-ignore: Missing parameter annotation [2] 297*523fa7a6SAndroid Build Coastguard Worker types, 298*523fa7a6SAndroid Build Coastguard Worker args: Tuple[Value, ...] = (), 299*523fa7a6SAndroid Build Coastguard Worker kwargs: Optional[Dict[str, Value]] = None, 300*523fa7a6SAndroid Build Coastguard Worker ) -> Value: 301*523fa7a6SAndroid Build Coastguard Worker """ 302*523fa7a6SAndroid Build Coastguard Worker This function is invoked every time an aten operation is called. 303*523fa7a6SAndroid Build Coastguard Worker 304*523fa7a6SAndroid Build Coastguard Worker Args: 305*523fa7a6SAndroid Build Coastguard Worker func_overload: The function that was called that invoked this 306*523fa7a6SAndroid Build Coastguard Worker torch_dispatch call 307*523fa7a6SAndroid Build Coastguard Worker types: 308*523fa7a6SAndroid Build Coastguard Worker args: Arguments that were passed into the function. Each argument 309*523fa7a6SAndroid Build Coastguard Worker has type PythonTensor. 310*523fa7a6SAndroid Build Coastguard Worker kwargs: Keyword arguments that were passed into the function. Each 311*523fa7a6SAndroid Build Coastguard Worker argument has type PythonTensor. 
312*523fa7a6SAndroid Build Coastguard Worker """ 313*523fa7a6SAndroid Build Coastguard Worker func = func_overload.overloadpacket 314*523fa7a6SAndroid Build Coastguard Worker 315*523fa7a6SAndroid Build Coastguard Worker kwargs = kwargs or {} 316*523fa7a6SAndroid Build Coastguard Worker if is_out_variant(func._qualified_op_name, func_overload._overloadname): 317*523fa7a6SAndroid Build Coastguard Worker out_args = extract_out_arguments(func_overload._schema, kwargs) 318*523fa7a6SAndroid Build Coastguard Worker out_args_iter = [out_args] if not isinstance(out_args, list) else out_args 319*523fa7a6SAndroid Build Coastguard Worker for out_arg_name, out_arg_val in out_args_iter: 320*523fa7a6SAndroid Build Coastguard Worker if isinstance(out_arg_val, PythonTensor) and out_arg_val.is_immutable: 321*523fa7a6SAndroid Build Coastguard Worker raise RuntimeError( 322*523fa7a6SAndroid Build Coastguard Worker "Immutable tensor `{}` is potentially getting modified by {}".format( 323*523fa7a6SAndroid Build Coastguard Worker out_arg_name, format_schema_name(func_overload._schema) 324*523fa7a6SAndroid Build Coastguard Worker ) 325*523fa7a6SAndroid Build Coastguard Worker ) 326*523fa7a6SAndroid Build Coastguard Worker 327*523fa7a6SAndroid Build Coastguard Worker # pyre-fixme[16]: Module `pytree` has no attribute `tree_map`. 328*523fa7a6SAndroid Build Coastguard Worker proxy_args = ex_pytree.tree_map(unwrap_proxy, args) 329*523fa7a6SAndroid Build Coastguard Worker # pyre-fixme[16]: Module `pytree` has no attribute `tree_map`. 
330*523fa7a6SAndroid Build Coastguard Worker proxy_kwargs = ex_pytree.tree_map(unwrap_proxy, kwargs) 331*523fa7a6SAndroid Build Coastguard Worker 332*523fa7a6SAndroid Build Coastguard Worker # Get the output of the function 333*523fa7a6SAndroid Build Coastguard Worker g = _EnableTorchFunction() 334*523fa7a6SAndroid Build Coastguard Worker try: 335*523fa7a6SAndroid Build Coastguard Worker proxy_out = ( 336*523fa7a6SAndroid Build Coastguard Worker func_overload(*proxy_args, **proxy_kwargs) 337*523fa7a6SAndroid Build Coastguard Worker if DispatchTracer.get() or torchdynamo_enabled 338*523fa7a6SAndroid Build Coastguard Worker # Disable node creation when no tracer is active. 339*523fa7a6SAndroid Build Coastguard Worker else DummyProxy() 340*523fa7a6SAndroid Build Coastguard Worker ) 341*523fa7a6SAndroid Build Coastguard Worker finally: 342*523fa7a6SAndroid Build Coastguard Worker del g 343*523fa7a6SAndroid Build Coastguard Worker 344*523fa7a6SAndroid Build Coastguard Worker with no_dispatch(): 345*523fa7a6SAndroid Build Coastguard Worker real_out = func_overload(*args, **kwargs) 346*523fa7a6SAndroid Build Coastguard Worker 347*523fa7a6SAndroid Build Coastguard Worker # Kind of a hacky way to test if an op is in-place or not 348*523fa7a6SAndroid Build Coastguard Worker if func.__name__[-1] == "_" and func.__name__[0] != "_": 349*523fa7a6SAndroid Build Coastguard Worker if isinstance(args[0], PythonTensor): 350*523fa7a6SAndroid Build Coastguard Worker args[0].proxy = proxy_out 351*523fa7a6SAndroid Build Coastguard Worker 352*523fa7a6SAndroid Build Coastguard Worker if not torch.fx.traceback.has_preserved_node_meta(): 353*523fa7a6SAndroid Build Coastguard Worker proxy_out.node.meta["stack_trace"] = json.dumps(get_stacktrace()) 354*523fa7a6SAndroid Build Coastguard Worker 355*523fa7a6SAndroid Build Coastguard Worker # Wrap the output tensors with the PythonTensor subclass to propagate to 356*523fa7a6SAndroid Build Coastguard Worker # future tracing 357*523fa7a6SAndroid 
Build Coastguard Worker def wrap_with_proxy(e: LeafValue, proxy: torch.fx.Proxy) -> LeafValue: 358*523fa7a6SAndroid Build Coastguard Worker # Some ops (like native_batch_norm_backward) return undefined tensors that get 359*523fa7a6SAndroid Build Coastguard Worker # converted into None in python. 360*523fa7a6SAndroid Build Coastguard Worker # As the function signature expects tensors, if we directly return these None 361*523fa7a6SAndroid Build Coastguard Worker # tensors back to C++, we'll error. 362*523fa7a6SAndroid Build Coastguard Worker if e is None: 363*523fa7a6SAndroid Build Coastguard Worker e = torch.empty(()) 364*523fa7a6SAndroid Build Coastguard Worker 365*523fa7a6SAndroid Build Coastguard Worker if isinstance(e, torch.Tensor): 366*523fa7a6SAndroid Build Coastguard Worker return PythonTensor(e, proxy) 367*523fa7a6SAndroid Build Coastguard Worker 368*523fa7a6SAndroid Build Coastguard Worker # Inplace and out-variant ops may return one of their arguments, which is already 369*523fa7a6SAndroid Build Coastguard Worker # a PythonTensor. In this case, we need to update the PythonTensor's associated 370*523fa7a6SAndroid Build Coastguard Worker # proxy to the newly created proxy. 
371*523fa7a6SAndroid Build Coastguard Worker if isinstance(e, PythonTensor): 372*523fa7a6SAndroid Build Coastguard Worker e.update_proxy(proxy) 373*523fa7a6SAndroid Build Coastguard Worker return e 374*523fa7a6SAndroid Build Coastguard Worker 375*523fa7a6SAndroid Build Coastguard Worker return e 376*523fa7a6SAndroid Build Coastguard Worker 377*523fa7a6SAndroid Build Coastguard Worker retval = None 378*523fa7a6SAndroid Build Coastguard Worker if not isinstance(real_out, (list, tuple)): 379*523fa7a6SAndroid Build Coastguard Worker retval = single_return(real_out, proxy_out, wrap_with_proxy) 380*523fa7a6SAndroid Build Coastguard Worker else: 381*523fa7a6SAndroid Build Coastguard Worker retval = tree_return(real_out, proxy_out, wrap_with_proxy, type(real_out)) 382*523fa7a6SAndroid Build Coastguard Worker return retval 383*523fa7a6SAndroid Build Coastguard Worker 384*523fa7a6SAndroid Build Coastguard Worker 385*523fa7a6SAndroid Build Coastguard Worker@contextmanager 386*523fa7a6SAndroid Build Coastguard Workerdef using_tracer(tracer: Optional["DispatchTracer"]) -> Generator[None, None, None]: 387*523fa7a6SAndroid Build Coastguard Worker """ 388*523fa7a6SAndroid Build Coastguard Worker Set the "current" global tracer within the scope of using_tracer 389*523fa7a6SAndroid Build Coastguard Worker context manager. 390*523fa7a6SAndroid Build Coastguard Worker 391*523fa7a6SAndroid Build Coastguard Worker Since various things we want to capture today with torch_dispatch 392*523fa7a6SAndroid Build Coastguard Worker does not "trap" into dispatcher really (for example, cond() and 393*523fa7a6SAndroid Build Coastguard Worker shape()), we need a separate singleton tracer exposed to user space 394*523fa7a6SAndroid Build Coastguard Worker in addition to Dispatcher to trigger graph capturing. 
395*523fa7a6SAndroid Build Coastguard Worker """ 396*523fa7a6SAndroid Build Coastguard Worker global TRACER 397*523fa7a6SAndroid Build Coastguard Worker TRACER, prev = tracer, TRACER 398*523fa7a6SAndroid Build Coastguard Worker try: 399*523fa7a6SAndroid Build Coastguard Worker yield 400*523fa7a6SAndroid Build Coastguard Worker finally: 401*523fa7a6SAndroid Build Coastguard Worker TRACER = prev 402*523fa7a6SAndroid Build Coastguard Worker 403*523fa7a6SAndroid Build Coastguard Worker 404*523fa7a6SAndroid Build Coastguard Workerclass DispatchTracer(fx.Tracer): 405*523fa7a6SAndroid Build Coastguard Worker def __init__(self) -> None: 406*523fa7a6SAndroid Build Coastguard Worker super().__init__() 407*523fa7a6SAndroid Build Coastguard Worker self.root: torch.nn.Module = torch.nn.Module() 408*523fa7a6SAndroid Build Coastguard Worker self.tensor_attrs: Dict[torch.Tensor, str] = {} 409*523fa7a6SAndroid Build Coastguard Worker self.submodules: Dict[fx.GraphModule, str] = {} 410*523fa7a6SAndroid Build Coastguard Worker 411*523fa7a6SAndroid Build Coastguard Worker def call_module( 412*523fa7a6SAndroid Build Coastguard Worker self, 413*523fa7a6SAndroid Build Coastguard Worker m: torch.nn.Module, 414*523fa7a6SAndroid Build Coastguard Worker forward: Callable[..., Value], 415*523fa7a6SAndroid Build Coastguard Worker args: Tuple[Value, ...], 416*523fa7a6SAndroid Build Coastguard Worker kwargs: Dict[str, Value], 417*523fa7a6SAndroid Build Coastguard Worker ) -> Value: 418*523fa7a6SAndroid Build Coastguard Worker return forward(*args, **kwargs) 419*523fa7a6SAndroid Build Coastguard Worker 420*523fa7a6SAndroid Build Coastguard Worker def _module_getattr( 421*523fa7a6SAndroid Build Coastguard Worker self, attr: str, attr_val: Value, parameter_proxy_cache: Dict[str, torch.Tensor] 422*523fa7a6SAndroid Build Coastguard Worker ) -> Value: 423*523fa7a6SAndroid Build Coastguard Worker if isinstance(attr_val, torch.nn.Parameter): 424*523fa7a6SAndroid Build Coastguard Worker for n, p in 
self.root.named_parameters(): 425*523fa7a6SAndroid Build Coastguard Worker if attr_val is p: 426*523fa7a6SAndroid Build Coastguard Worker if n not in parameter_proxy_cache: 427*523fa7a6SAndroid Build Coastguard Worker proxy = self.create_proxy("get_attr", n, (), {}) 428*523fa7a6SAndroid Build Coastguard Worker parameter_proxy_cache[n] = PythonTensor(attr_val, proxy) 429*523fa7a6SAndroid Build Coastguard Worker return parameter_proxy_cache[n] 430*523fa7a6SAndroid Build Coastguard Worker return attr_val 431*523fa7a6SAndroid Build Coastguard Worker return attr_val 432*523fa7a6SAndroid Build Coastguard Worker 433*523fa7a6SAndroid Build Coastguard Worker def create_arg(self, a: Value) -> torch.fx.Node: # noqa: C901 434*523fa7a6SAndroid Build Coastguard Worker if isinstance(a, torch.nn.Parameter): 435*523fa7a6SAndroid Build Coastguard Worker for n, p in self.root.named_parameters(): 436*523fa7a6SAndroid Build Coastguard Worker if a is p: 437*523fa7a6SAndroid Build Coastguard Worker return self.create_node("get_attr", n, (), {}) 438*523fa7a6SAndroid Build Coastguard Worker qualname: Optional[str] = None 439*523fa7a6SAndroid Build Coastguard Worker 440*523fa7a6SAndroid Build Coastguard Worker if not qualname: 441*523fa7a6SAndroid Build Coastguard Worker i = 0 442*523fa7a6SAndroid Build Coastguard Worker while True: 443*523fa7a6SAndroid Build Coastguard Worker qualname = f"_param_constant{i}" 444*523fa7a6SAndroid Build Coastguard Worker if not hasattr(self.root, qualname): 445*523fa7a6SAndroid Build Coastguard Worker break 446*523fa7a6SAndroid Build Coastguard Worker i += 1 447*523fa7a6SAndroid Build Coastguard Worker setattr(self.root, qualname, a) 448*523fa7a6SAndroid Build Coastguard Worker 449*523fa7a6SAndroid Build Coastguard Worker return self.create_node("get_attr", qualname, (), {}) 450*523fa7a6SAndroid Build Coastguard Worker 451*523fa7a6SAndroid Build Coastguard Worker if isinstance(a, torch.Tensor): 452*523fa7a6SAndroid Build Coastguard Worker qualname: 
Optional[str] = self.tensor_attrs.get(a) 453*523fa7a6SAndroid Build Coastguard Worker 454*523fa7a6SAndroid Build Coastguard Worker if not qualname: 455*523fa7a6SAndroid Build Coastguard Worker i = 0 456*523fa7a6SAndroid Build Coastguard Worker while True: 457*523fa7a6SAndroid Build Coastguard Worker qualname = f"_tensor_constant{i}" 458*523fa7a6SAndroid Build Coastguard Worker if not hasattr(self.root, qualname): 459*523fa7a6SAndroid Build Coastguard Worker break 460*523fa7a6SAndroid Build Coastguard Worker i += 1 461*523fa7a6SAndroid Build Coastguard Worker self.tensor_attrs[a] = qualname 462*523fa7a6SAndroid Build Coastguard Worker self.root.register_buffer(qualname, a) 463*523fa7a6SAndroid Build Coastguard Worker 464*523fa7a6SAndroid Build Coastguard Worker return self.create_node("get_attr", qualname, (), {}) 465*523fa7a6SAndroid Build Coastguard Worker 466*523fa7a6SAndroid Build Coastguard Worker # higher-order operator 467*523fa7a6SAndroid Build Coastguard Worker if isinstance(a, fx.GraphModule): 468*523fa7a6SAndroid Build Coastguard Worker if a not in self.submodules: 469*523fa7a6SAndroid Build Coastguard Worker name_submodule = f"submodule_{len(self.submodules)}" 470*523fa7a6SAndroid Build Coastguard Worker self.root.add_module(name_submodule, a) 471*523fa7a6SAndroid Build Coastguard Worker self.submodules[a] = name_submodule 472*523fa7a6SAndroid Build Coastguard Worker return self.create_node("get_attr", self.submodules[a], (), {}) 473*523fa7a6SAndroid Build Coastguard Worker 474*523fa7a6SAndroid Build Coastguard Worker return super().create_arg(a) # pyre-fixme[7] 475*523fa7a6SAndroid Build Coastguard Worker 476*523fa7a6SAndroid Build Coastguard Worker @staticmethod 477*523fa7a6SAndroid Build Coastguard Worker def get() -> "DispatchTracer": 478*523fa7a6SAndroid Build Coastguard Worker return TRACER 479*523fa7a6SAndroid Build Coastguard Worker 480*523fa7a6SAndroid Build Coastguard Worker def trace( # pyre-fixme[14,15] 481*523fa7a6SAndroid Build Coastguard 
Worker self, 482*523fa7a6SAndroid Build Coastguard Worker root: Callable[..., Value], 483*523fa7a6SAndroid Build Coastguard Worker concrete_args: Tuple[Value, ...] = (), 484*523fa7a6SAndroid Build Coastguard Worker in_spec: Optional[TreeSpec] = None, 485*523fa7a6SAndroid Build Coastguard Worker ) -> Value: 486*523fa7a6SAndroid Build Coastguard Worker """ 487*523fa7a6SAndroid Build Coastguard Worker Traces the given graph module. 488*523fa7a6SAndroid Build Coastguard Worker """ 489*523fa7a6SAndroid Build Coastguard Worker with using_tracer(self): 490*523fa7a6SAndroid Build Coastguard Worker return self._trace(root, concrete_args=concrete_args, in_spec=in_spec) 491*523fa7a6SAndroid Build Coastguard Worker 492*523fa7a6SAndroid Build Coastguard Worker def _trace( 493*523fa7a6SAndroid Build Coastguard Worker self, 494*523fa7a6SAndroid Build Coastguard Worker root: Callable[..., Value], 495*523fa7a6SAndroid Build Coastguard Worker concrete_args: Tuple[Value, ...], 496*523fa7a6SAndroid Build Coastguard Worker in_spec: Optional[TreeSpec], 497*523fa7a6SAndroid Build Coastguard Worker ) -> Value: 498*523fa7a6SAndroid Build Coastguard Worker self.root = torch.nn.Module() 499*523fa7a6SAndroid Build Coastguard Worker root_fn = root 500*523fa7a6SAndroid Build Coastguard Worker 501*523fa7a6SAndroid Build Coastguard Worker tracer_cls = getattr(self, "__class__", None) 502*523fa7a6SAndroid Build Coastguard Worker self.graph = fx.Graph(tracer_cls=tracer_cls) 503*523fa7a6SAndroid Build Coastguard Worker # Don't support module, so tensor_attrs is always empty 504*523fa7a6SAndroid Build Coastguard Worker self.tensor_attrs = {} 505*523fa7a6SAndroid Build Coastguard Worker 506*523fa7a6SAndroid Build Coastguard Worker # Wrap all inputs as a PythonTensor subclass and insert them into the FX 507*523fa7a6SAndroid Build Coastguard Worker # graph as placeholder nodes 508*523fa7a6SAndroid Build Coastguard Worker def wrap(arg: Value, i: int) -> Value: 509*523fa7a6SAndroid Build Coastguard Worker 
placeholder = self.create_proxy("placeholder", f"ph_{i}", (), {}) 510*523fa7a6SAndroid Build Coastguard Worker if isinstance(arg, torch.Tensor): 511*523fa7a6SAndroid Build Coastguard Worker return PythonTensor(arg, placeholder, is_immutable=True) 512*523fa7a6SAndroid Build Coastguard Worker else: 513*523fa7a6SAndroid Build Coastguard Worker # torch._assert( 514*523fa7a6SAndroid Build Coastguard Worker # placeholder == arg, 515*523fa7a6SAndroid Build Coastguard Worker # f"ph_{i} has been specialized to have value {arg}", 516*523fa7a6SAndroid Build Coastguard Worker # ) 517*523fa7a6SAndroid Build Coastguard Worker return arg 518*523fa7a6SAndroid Build Coastguard Worker 519*523fa7a6SAndroid Build Coastguard Worker tree_args = [wrap(arg, i) for i, arg in enumerate(concrete_args)] 520*523fa7a6SAndroid Build Coastguard Worker if in_spec: 521*523fa7a6SAndroid Build Coastguard Worker tree_args = pytree.tree_unflatten(tree_args, in_spec) 522*523fa7a6SAndroid Build Coastguard Worker 523*523fa7a6SAndroid Build Coastguard Worker tree_out = root_fn(*tree_args) 524*523fa7a6SAndroid Build Coastguard Worker 525*523fa7a6SAndroid Build Coastguard Worker out_args, _ = pytree.tree_flatten(tree_out) 526*523fa7a6SAndroid Build Coastguard Worker 527*523fa7a6SAndroid Build Coastguard Worker def unwrap(out: LeafValue) -> Union[LeafValue, torch.fx.Proxy]: 528*523fa7a6SAndroid Build Coastguard Worker # it's legit for a model to return a list of items some of which 529*523fa7a6SAndroid Build Coastguard Worker # are None 530*523fa7a6SAndroid Build Coastguard Worker if out is None: 531*523fa7a6SAndroid Build Coastguard Worker return None 532*523fa7a6SAndroid Build Coastguard Worker if not isinstance(out, torch.Tensor): 533*523fa7a6SAndroid Build Coastguard Worker raise TypeError( 534*523fa7a6SAndroid Build Coastguard Worker f"Expect model to return torch.Tensor, got type: '{type(out)}' (value: {out})." 
535*523fa7a6SAndroid Build Coastguard Worker ) 536*523fa7a6SAndroid Build Coastguard Worker return unwrap_proxy(out) 537*523fa7a6SAndroid Build Coastguard Worker 538*523fa7a6SAndroid Build Coastguard Worker returns = [unwrap(out) for out in out_args] 539*523fa7a6SAndroid Build Coastguard Worker 540*523fa7a6SAndroid Build Coastguard Worker return_annotation = None 541*523fa7a6SAndroid Build Coastguard Worker # some ops like torch.sub doesn't have annotations 542*523fa7a6SAndroid Build Coastguard Worker if hasattr(root_fn, "__annotations__"): 543*523fa7a6SAndroid Build Coastguard Worker return_annotation = root_fn.__annotations__.get("return", None) 544*523fa7a6SAndroid Build Coastguard Worker 545*523fa7a6SAndroid Build Coastguard Worker self.create_proxy( 546*523fa7a6SAndroid Build Coastguard Worker "output", 547*523fa7a6SAndroid Build Coastguard Worker "output", 548*523fa7a6SAndroid Build Coastguard Worker (returns,), 549*523fa7a6SAndroid Build Coastguard Worker {}, 550*523fa7a6SAndroid Build Coastguard Worker type_expr=return_annotation, 551*523fa7a6SAndroid Build Coastguard Worker ) 552*523fa7a6SAndroid Build Coastguard Worker 553*523fa7a6SAndroid Build Coastguard Worker self.submodule_paths = None 554*523fa7a6SAndroid Build Coastguard Worker 555*523fa7a6SAndroid Build Coastguard Worker return tree_out 556*523fa7a6SAndroid Build Coastguard Worker 557*523fa7a6SAndroid Build Coastguard Worker 558*523fa7a6SAndroid Build Coastguard WorkerTRACER: Optional[DispatchTracer] = None 559*523fa7a6SAndroid Build Coastguard WorkerTORCHDYNAMO_ENABLED: bool = False 560*523fa7a6SAndroid Build Coastguard Worker 561*523fa7a6SAndroid Build Coastguard Worker 562*523fa7a6SAndroid Build Coastguard Worker@contextmanager 563*523fa7a6SAndroid Build Coastguard Workerdef using_dynamo(val: bool) -> Generator[None, None, None]: 564*523fa7a6SAndroid Build Coastguard Worker global TORCHDYNAMO_ENABLED 565*523fa7a6SAndroid Build Coastguard Worker TORCHDYNAMO_ENABLED, prev = val, 
TORCHDYNAMO_ENABLED 566*523fa7a6SAndroid Build Coastguard Worker try: 567*523fa7a6SAndroid Build Coastguard Worker yield 568*523fa7a6SAndroid Build Coastguard Worker finally: 569*523fa7a6SAndroid Build Coastguard Worker TORCHDYNAMO_ENABLED = prev 570*523fa7a6SAndroid Build Coastguard Worker 571*523fa7a6SAndroid Build Coastguard Worker 572*523fa7a6SAndroid Build Coastguard Workerdef flattened_dispatch_trace( 573*523fa7a6SAndroid Build Coastguard Worker f: Callable[..., Value], 574*523fa7a6SAndroid Build Coastguard Worker args: Tuple[LeafValue, ...], 575*523fa7a6SAndroid Build Coastguard Worker guards: Set[Guard], 576*523fa7a6SAndroid Build Coastguard Worker in_spec: Optional[TreeSpec] = None, 577*523fa7a6SAndroid Build Coastguard Worker enable_functionalization: bool = True, 578*523fa7a6SAndroid Build Coastguard Worker) -> Tuple[torch.fx.GraphModule, Value]: 579*523fa7a6SAndroid Build Coastguard Worker if not isinstance(args, tuple): 580*523fa7a6SAndroid Build Coastguard Worker raise TypeError(f"Expecting 'args' to be a tuple, got: {type(args)}") 581*523fa7a6SAndroid Build Coastguard Worker 582*523fa7a6SAndroid Build Coastguard Worker tracer = DispatchTracer() 583*523fa7a6SAndroid Build Coastguard Worker 584*523fa7a6SAndroid Build Coastguard Worker if enable_functionalization: 585*523fa7a6SAndroid Build Coastguard Worker f = functionalize(f, remove="mutations_and_views") 586*523fa7a6SAndroid Build Coastguard Worker tree_out = tracer.trace(f, concrete_args=args, in_spec=in_spec) 587*523fa7a6SAndroid Build Coastguard Worker 588*523fa7a6SAndroid Build Coastguard Worker name = type(f).__name__ if isinstance(f, torch.nn.Module) else f.__name__ 589*523fa7a6SAndroid Build Coastguard Worker gm = torch.fx.GraphModule(tracer.root, tracer.graph, name) 590*523fa7a6SAndroid Build Coastguard Worker 591*523fa7a6SAndroid Build Coastguard Worker return (gm, tree_out) 592*523fa7a6SAndroid Build Coastguard Worker 593*523fa7a6SAndroid Build Coastguard Worker 594*523fa7a6SAndroid Build 
@dataclass
class ExirDynamoConfig:
    """
    Manage Exir-specific configurations of Dynamo.

    ``dynamo_trace`` forwards every field, via ``asdict``, to
    ``torchdynamo.config.patch`` under the same key.
    """

    allow_rnn: bool = True
    verbose: bool = True
    # Also passed explicitly to ``torchdynamo.export`` by ``dynamo_trace``.
    assume_static_by_default: bool = False


def flatten_output(gm: torch.fx.GraphModule) -> None:
    """
    Rewrite the output node of ``gm`` in place so that it returns its result
    as a flattened list. This keeps it consistent with the result of EXIR's
    tracer.

    Only the graph is modified; ``gm.recompile()`` is not called here.

    Raises:
        RuntimeError: if ``gm.graph`` has no ``output`` node.
    """
    # The output node is normally last, so scan from the end.
    for node in reversed(gm.graph.nodes):
        if node.op == "output":
            assert len(node.args) == 1
            outputs = node.args[0]
            returns, _ = pytree.tree_flatten(outputs)
            node.args = (returns,)
            return
    raise RuntimeError(f"Could not find an output node in {gm.graph}")


def _default_decomposition_table(
    _use_old_decomp_table: bool = False,
) -> Dict[torch._ops.OpOverload, Callable[..., Value]]:
    """
    Return the decomposition table used when tracing to an ATen graph.

    Args:
        _use_old_decomp_table: When true, return decompositions for a small
            fixed op set only. Otherwise return ``default_decompositions()``.
    """
    if _use_old_decomp_table:
        decomp_opset = [
            torch.ops.aten.log_sigmoid_forward,
            torch.ops.aten.ones,
            torch.ops.aten.arange.default,
            torch.ops.aten.arange.start,
            torch.ops.aten.transpose,
        ]
        # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.e...
        return get_decompositions(decomp_opset)
    # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.exir....
    return default_decompositions()


def dynamo_trace(
    f: Callable[..., Value],
    # pyre-ignore
    args: Tuple[Any, ...],
    aten_graph: bool,
    tracing_mode: str = "real",
    dynamo_config: Optional[ExirDynamoConfig] = None,
    # pyre-ignore
    dynamic_shapes: Optional[List[Any]] = None,
    _use_old_decomp_table: bool = False,
) -> Tuple[torch.fx.GraphModule, Set[Guard]]:
    """
    Capture ``f`` with ``torchdynamo.export`` and return its result.

    TODO: Once we fully migrate to torchdynamo frontend, we will remove
    this config option altogether. For now, it helps with quick
    experiments with playing around with TorchDynamo

    Args:
        f: Callable to export.
        args: Example inputs. A deep copy is passed to the exported callable,
            so the caller's objects are not handed to TorchDynamo directly.
        aten_graph: Export an ATen-level graph. When true, the table from
            ``_default_decomposition_table`` is applied.
        tracing_mode: Forwarded to ``torchdynamo.export``.
        dynamo_config: Dynamo settings. Defaults to ``ExirDynamoConfig()``.
        dynamic_shapes: Forwarded to ``torchdynamo.export``.
        _use_old_decomp_table: Forwarded to ``_default_decomposition_table``.

    Returns:
        Whatever ``torchdynamo.export(...)(*args)`` returns, unpacked by
        callers as ``(graph_module, guards)``.

    Raises:
        ExportError: ``NOT_SUPPORTED`` when TorchDynamo raises ``Unsupported``.
        InternalError: for any other exception raised during export.
    """
    if dynamo_config is None:
        dynamo_config = ExirDynamoConfig()

    with torchdynamo.config.patch(
        asdict(dynamo_config)
    ), setting_python_recursive_limit(2000):
        torchdynamo.reset()
        try:
            # TODO merge executorch functionalization with official
            # functionalization
            # pyre-fixme[7]: Expected `Tuple[GraphModule, Set[Guard]]` but got
            # `ExportResult`.
            return torchdynamo.export(
                f,
                aten_graph=aten_graph,
                tracing_mode=tracing_mode,
                assume_static_by_default=dynamo_config.assume_static_by_default,
                decomposition_table=(
                    _default_decomposition_table(_use_old_decomp_table)
                    if aten_graph
                    else None
                ),
                dynamic_shapes=dynamic_shapes,
            )(
                *copy.deepcopy(args),
            )
        except torchdynamo.exc.Unsupported as exc:
            raise ExportError(
                ExportErrorType.NOT_SUPPORTED,
                "The user code is using a feature we don't support. "
                "Please try torchdynamo.explain() to get possible the reasons",
            ) from exc
        except Exception as exc:
            raise InternalError(
                "torchdynamo internal error occured. Please see above stacktrace"
            ) from exc


def dispatch_trace(
    f: Callable[..., Value],
    args: Tuple[Value, ...],
) -> torch.fx.GraphModule:
    """
    Executes a given callable `f` with a given tuple of arguments. During
    execution, Tensor operations are recorded in a fx.GraphModule, which is then
    returned.

    Tracing runs in two passes over flattened inputs:
    1. Trace ``f`` without functionalization, then drop dead nodes.
    2. Re-trace the first-pass graph, through a wrapper that preserves node
       metadata, with functionalization enabled.

    If ``TORCHDYNAMO_ENABLED`` is set, ``f`` is first captured with
    ``dynamo_trace``.

    Args:
        f: A `nn.Module` or a Python function that implements an ML program.
        args: A tuple of arguments of any type to be used as inputs for the tracing run.

    Returns:
        EXIR contained in a fx.GraphModule. Its ``in_spec``/``out_spec``
        attributes hold the input/output tree specs.
    """
    trace_func = f
    guards: Set[Guard] = set()
    if TORCHDYNAMO_ENABLED:
        # Capture with TorchDynamo first; dynamo_trace deep-copies ``args``
        # itself before running the program.
        trace_func, guards = dynamo_trace(trace_func, args, False)

    # Copying args is safer in case downstream implementations of trace_func mutate them
    trace_args, in_spec = pytree.tree_flatten(args)

    # First pass: plain trace, no functionalization.
    in_args = copy.deepcopy(tuple(trace_args))
    flat_gm, tree_out = flattened_dispatch_trace(
        trace_func,
        in_args,
        guards,
        in_spec,
        enable_functionalization=False,
    )

    _, out_spec = pytree.tree_flatten(tree_out)

    # pyre-fixme[16]: `GraphModule` has no attribute `in_spec`.
    flat_gm.in_spec = in_spec
    # pyre-fixme[16]: `GraphModule` has no attribute `out_spec`.
    flat_gm.out_spec = out_spec

    # TODO (tmanlaibaatar) This is bit clowny, but our
    # dispatch_trace sometimes creates unused node that
    # breaks functionalization. it seems too much trouble
    # to fix it properly since dispatch_trace will be deprecated soon.
    # Basically dispatch_trace struggles on:
    # def f(x: torch.Tensor) -> torch.Tensor:
    #    return torch.ones(6, dtype=x.dtype)
    changed = flat_gm.graph.eliminate_dead_code()
    if changed:
        flat_gm.recompile()

    # Fresh copy of the inputs, in case the first pass mutated them.
    in_args = copy.deepcopy(tuple(trace_args))
    assert callable(flat_gm)

    # This wrapper is used for preserving the stacktrace
    # during second round of tracing. It closes over ``flat_gm`` (never
    # rebound), so assigning the second-pass result to ``gm`` below cannot
    # change which module it calls.
    # pyre-ignore
    def graph_with_interpreter(*args):
        try:
            args = fx_pytree.tree_flatten_spec(args, flat_gm.in_spec)  # type: ignore[assignment]
        except Exception:
            _, received_spec = pytree.tree_flatten(args)
            raise RuntimeError(
                "Trying to flatten user inputs with exported input tree spec: \n"
                f"{flat_gm.in_spec}\n"
                "but actually got inputs with tree spec of: \n"
                f"{received_spec}"
            )
        with torch.fx.traceback.preserve_node_meta():
            res = flat_gm(*args)

        if flat_gm.out_spec is not None:
            try:
                res = pytree.tree_unflatten(res, flat_gm.out_spec)
            except Exception:
                _, received_spec = pytree.tree_flatten(res)
                raise RuntimeError(
                    "Trying to flatten user outputs with exported output tree spec: \n"
                    f"{flat_gm.out_spec}\n"
                    "but actually got outputs with tree spec of: \n"
                    f"{received_spec}"
                )
        return res

    # Second pass: re-trace the first-pass graph with functionalization.
    gm, _ = flattened_dispatch_trace(
        graph_with_interpreter,
        in_args,
        guards,
        in_spec,
        enable_functionalization=True,
    )

    gm.in_spec = in_spec
    gm.out_spec = out_spec

    return gm