# mypy: allow-untyped-defs
import builtins
import collections
import contextlib
import copy
import functools
import inspect
import math
import os
import warnings
from itertools import chain
from types import CodeType, FunctionType, ModuleType
from typing import (
    Any,
    Callable,
    Dict,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    Type,
    Union,
)

import torch
import torch.utils._pytree as pytree
from torch._C import ScriptObject  # type: ignore[attr-defined]
from torch._library.fake_class_registry import FakeScriptObject

from ._compatibility import compatibility
from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph
from .graph_module import GraphModule
from ._lazy_graph_module import _make_graph_module
from .node import Argument, base_types, map_aggregate
from .proxy import ParameterProxy, Proxy, TracerBase, Scope, ScopeContextManager

HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS

# These need to run in global scope to handle nested calls correctly
_orig_module_call: Callable = torch.nn.Module.__call__
_orig_module_getattr: Callable = torch.nn.Module.__getattr__

_proxyable_classes: Dict[Type, None] = {}

_is_fx_tracing_flag = False


def is_fx_tracing():
    return _is_fx_tracing_flag


@compatibility(is_backward_compatible=True)
class ProxyableClassMeta(type):
    """
    ProxyableClassMeta allows you to make construction of a given Python class
    symbolically traceable. For example::

        import torch
        import torch.fx

        class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
            def __init__(self, left, right):
                self.left, self.right = left, right

            def add(self, other):
                l = self.left + other.left
                r = self.right + other.right
                return TensorPair(l, r)

            def mul(self, other):
                l = self.left * other.left
                r = self.right * other.right
                return TensorPair(l, r)

        def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor):
            s = x.add(TensorPair(y, y))
            return s.mul(x)

        x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
        y = torch.randn(5, 3)
        ref_out = use_tensor_pair_ctor(x, y)

        traced = torch.fx.symbolic_trace(use_tensor_pair_ctor)
        print(traced.code)
        '''
        def forward(self, x : __main___TensorPair, y : torch.Tensor):
            tensor_pair = __main___TensorPair(y, y);  y = None
            add = x.add(tensor_pair);  tensor_pair = None
            mul = add.mul(x);  add = x = None
            return mul
        '''

    From this example, we can see that construction of a class (``TensorPair``)
    defined with ``ProxyableClassMeta`` as metaclass can be recorded in symbolic
    tracing.
    """

    def __init__(cls, name, bases, attrs):
        _proxyable_classes.setdefault(cls)
        super().__init__(name, bases, attrs)

    def __call__(cls, *args, **kwargs):
        instance = cls.__new__(cls)  # type: ignore[call-overload]

        if not is_fx_tracing():
            cls.__init__(instance, *args, **kwargs)  # type: ignore[misc]
            return instance

        found_proxies = []

        def check_proxy(a):
            if isinstance(a, Proxy):
                found_proxies.append(a)

        map_aggregate(args, check_proxy)
        map_aggregate(kwargs, check_proxy)

        if len(found_proxies) != 0:
            tracer = found_proxies[0].tracer
            return tracer.create_proxy("call_function", cls, args, kwargs)
        else:
            cls.__init__(instance, *args, **kwargs)  # type: ignore[misc]
            return instance


def _patch_function(fn: FunctionType, nargs: int) -> FunctionType:
    # We need to insert placeholder nodes for *args and **kwargs.
    # We can't call this function normally, otherwise it would try to unpack
    # them. Instead, let's make Python think that args and kwargs are normal
    # variables.
    co = fn.__code__
    co_flags = co.co_flags & ~HAS_VARSTUFF
    co_args: tuple
    if hasattr(co, "co_qualname"):
        # Python-3.11+ code signature
        co_args = (
            nargs,
            0,
            0,
            co.co_nlocals,
            co.co_stacksize,
            co_flags,
            co.co_code,
            co.co_consts,
            co.co_names,
            co.co_varnames,
            co.co_filename,
            co.co_name,
            co.co_qualname,  # type: ignore[attr-defined]
            co.co_firstlineno,
            co.co_lnotab,
            co.co_exceptiontable,  # type: ignore[attr-defined]
            co.co_freevars,
            co.co_cellvars,
        )
    elif hasattr(co, "co_posonlyargcount"):
        co_args = (
            nargs,
            0,
            0,
            co.co_nlocals,
            co.co_stacksize,
            co_flags,
            co.co_code,
            co.co_consts,
            co.co_names,
            co.co_varnames,
            co.co_filename,
            co.co_name,
            co.co_firstlineno,
            co.co_lnotab,
            co.co_freevars,
            co.co_cellvars,
        )
    else:
        co_args = (
            nargs,
            0,
            co.co_nlocals,
            co.co_stacksize,
            co_flags,
            co.co_code,
            co.co_consts,
            co.co_names,
            co.co_varnames,
            co.co_filename,
            co.co_name,
            co.co_firstlineno,
            co.co_lnotab,
            co.co_freevars,
            co.co_cellvars,
        )
    new_code = CodeType(*co_args)  # type: ignore[arg-type]
    return FunctionType(
        new_code, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__
    )


@compatibility(is_backward_compatible=False)
class PHBase:
    """
    Object representing an input placeholder to `concrete_args`
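
    For example (a sketch; ``f`` is a hypothetical function taking a dict)::

        gm = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH}})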
203    """
204
205    def __repr__(self):
206        return "PH"
207
208
209PH = PHBase()
210
211
212@compatibility(is_backward_compatible=False)
213class PHWithMeta(PHBase):
214    """
215    Object representing an input placeholder to `concrete_args`
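
    For example (a sketch; ``fn`` is a hypothetical traced function), the key
    survives tracing as an attribute on the resulting placeholder node::

        ph = PHWithMeta(ph_key="lhs")
        gm = torch.fx.symbolic_trace(fn, concrete_args={"x": ph})
        keys = [getattr(n, "ph_key", None) for n in gm.graph.nodes]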
216    """
217    def __init__(self, ph_key: Optional[str] = None):
218        super().__init__()
219
220        # Provide a hey for user to identify placeholder node during analysis
221        self.ph_key = ph_key
222
223
224def _transfer_attrs(fr, to):
225    for attr_name in dir(fr):
226        attr_val = getattr(fr, attr_name)
227        if (
228            not callable(attr_val)
229            and not attr_name.startswith("__")
230            and not hasattr(to, attr_name)
231        ):
232            setattr(to, attr_name, attr_val)
233
234
235@compatibility(is_backward_compatible=True)
236class Tracer(TracerBase):
237    # Reference: https://github.com/pytorch/pytorch/issues/54354
238    # The first line of this docstring overrides the one Sphinx generates for the
239    # documentation. We need it so that Sphinx doesn't leak `math`s path from the
240    # build environment (e.g. `<module 'math' from '/leaked/path').
241
242    """Tracer(autowrap_modules=(math,), autowrap_functions=())
243
244    ``Tracer`` is the class that implements the symbolic tracing functionality
245    of ``torch.fx.symbolic_trace``. A call to ``symbolic_trace(m)`` is equivalent
246    to ``Tracer().trace(m)``.
247
248    Tracer can be subclassed to override various behaviors of the tracing
249    process. The different behaviors that can be overridden are described
250    in the docstrings of the methods on this class.
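
    For example, a minimal sketch of a subclass that stops tracing into every
    submodule (``MyTracer`` and ``my_module`` are hypothetical names)::

        class MyTracer(torch.fx.Tracer):
            def is_leaf_module(self, m, module_qualified_name):
                return True

        graph = MyTracer().trace(my_module)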
251    """
252
253    # Not checking BC on this API because the default value for `autowrap_modules`
254    # includes the local filepath to the `math` module, which would jitter
255    # across machines.
256    @compatibility(is_backward_compatible=True)
257    def __init__(
258        self,
259        autowrap_modules: Tuple[ModuleType] = (math,),
260        autowrap_functions: Tuple[Callable, ...] = (),
261        param_shapes_constant: bool = False,
262    ) -> None:
263        # This method's signature is overridden by the first line of this class'
264        # docstring. If this method's signature is modified, the signature that
265        # overrides it also should be modified accordingly.
266
267        """
268        Construct a Tracer object.
269
270        Args:
271
272            autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`,
273                Python modules whose functions should be wrapped automatically
274                without needing to use fx.wrap(). Backward-compatibility for
275                this parameter is guaranteed.
276
277            autowrap_functions (Tuple[Callable, ...]): defaults to `()`,
278                Python functions that should be wrapped automatically without
279                needing to use fx.wrap(). Backward compatibility for this
280                parameter is guaranteed.
281
282            param_shapes_constant (bool): When this flag is set,  calls to shape,
283                size and a few other shape like attributes of a module's parameter
284                will be evaluated directly, rather than returning a new Proxy value
285                for an attribute access. Backward compatibility for this parameter
286                is guaranteed.
287        """
288
289        super().__init__()
290
291        # Functions we will eagerly wrap when we see them while tracing
292        # this captures both `math.sqrt()` and `from math import sqrt` automatically
293        self._autowrap_function_ids: Set[int] = {
294            id(value)
295            for name, value in chain(*[m.__dict__.items() for m in autowrap_modules])
296            if not name.startswith("_") and callable(value)
297        }
298        self._autowrap_function_ids.update({id(f) for f in autowrap_functions})
299
300        # Python modules to apply autowrap to at the start, in addition to
301        # modules we see while tracing
302        self._autowrap_search: List[ModuleType] = list(autowrap_modules)
303        self.param_shapes_constant = param_shapes_constant
304
305        self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None
306        self.root_module_name: str = ""
307        # Maps the containing module's name to the operator name
308        self.scope = Scope("", None)
309        # Records the module call stack
310        self.module_stack = collections.OrderedDict()
311        # Mapping of node name to module scope
312        self.node_name_to_scope: Dict[str, Tuple[str, type]] = {}
313
314    _qualname_counter: Dict[str, int] = collections.defaultdict(int)
315
316    @compatibility(is_backward_compatible=True)
317    def get_fresh_qualname(self, prefix: str) -> str:
318        """
319        Gets a fresh name for a prefix and returns it. This function ensures
320        that it will not clash with an existing attribute on the graph.
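
        For example (a sketch), assuming ``self.root`` has no attribute whose
        name starts with ``_tensor_constant``::

            name = self.get_fresh_qualname("_tensor_constant")  # "_tensor_constant0"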
321        """
322        # The idea here is that if the module doesn't have this prefix at all we
323        # should reset the counter to start from the beginning
324        # It's a ... little bit hacky (doesn't cover all cases) but the precise
325        # naming of the prefixes isn't a correctness issue, just a niceness
326        # issue
327        qualname = f"{prefix}0"
328        if not hasattr(self.root, qualname):
329            self._qualname_counter[prefix] = 0
330            return qualname
331
332        i = self._qualname_counter[prefix]
333        while True:
334            qualname = f"{prefix}{i}"
335            i += 1
336            if not hasattr(self.root, qualname):
337                break
338        self._qualname_counter[prefix] = i
339
340        return qualname
341
342    @compatibility(is_backward_compatible=True)
343    def create_arg(self, a: Any) -> "Argument":
344        """
345        A method to specify the behavior of tracing when preparing values to
346        be used as arguments to nodes in the ``Graph``.
347
348        By default, the behavior includes:
349
350        #. Iterate through collection types (e.g. tuple, list, dict) and recursively
351           call ``create_args`` on the elements.
352        #. Given a Proxy object, return a reference to the underlying IR ``Node``
353        #. Given a non-Proxy Tensor object, emit IR for various cases:
354
355            * For a Parameter, emit a ``get_attr`` node referring to that Parameter
356            * For a non-Parameter Tensor, store the Tensor away in a special
357              attribute referring to that attribute.
358
359        This method can be overridden to support more types.
360
361        Args:
362
363            a (Any): The value to be emitted as an ``Argument`` in the ``Graph``.
364
365
366        Returns:
367
368            The value ``a`` converted into the appropriate ``Argument``
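
        For example, a sketch of extending this to support a hypothetical
        ``MyWrapper`` container type::

            class MyTracer(Tracer):
                def create_arg(self, a):
                    if isinstance(a, MyWrapper):
                        return self.create_arg(a.unwrap())
                    return super().create_arg(a)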
369        """
370        # The base tracer is used to construct Graphs when there is no associated
371        # module hierarchy, so it can never create parameter references.
372        # The default tracer adds the ability to refer to parameters when
373        # tracing modules.
374        if isinstance(a, torch.nn.Parameter):
375            for n, p in self.root.named_parameters():
376                if a is p:
377                    return self.create_node("get_attr", n, (), {})
378            raise NameError("parameter is not a member of this module")
379        elif isinstance(a, torch.Tensor):
380            for n_, p_ in self.root.named_buffers():
381                if a is p_:
382                    return self.create_node("get_attr", n_, (), {})
383        elif isinstance(a, torch.nn.Module):
384            for n_, p_ in self.root.named_modules():
385                if a is p_:
386                    return self.create_node("get_attr", n_, (), {})
387        # For NamedTuple instances that appear literally as args, we emit
388        # a node to construct the NamedTuple and use that Node as the argument.
389        if isinstance(a, tuple) and hasattr(a, "_fields"):
390            args = tuple(self.create_arg(elem) for elem in a)
391            return self.create_node("call_function", a.__class__, args, {})
392
393        # Tensors do not have a reliable string repr() from which they can be
394        # constructed (and we probably don't want to rely on that, either), so
395        # for any constant Tensor values we encounter, first search for if they
396        # are an attribute of some module in the module hierarchy. If so, emit
397        # a get_attr to retrieve that tensor. Otherwise, we'll store away the
398        # tensor value into a special attribute on the Module s.t. we can
399        # retrieve it with a get_attr.
400        if isinstance(a, (torch.Tensor, ScriptObject, FakeScriptObject)):
401            qualname: Optional[str] = self.tensor_attrs.get(a)
402
403            # Tensor was not found in the Module hierarchy, stow it away in a
404            # special attribute and set the qualname to refer to that
405            if not qualname:
406                base_name = "_tensor_constant" if isinstance(a, torch.Tensor) else "_torchbind_obj"
407                qualname = self.get_fresh_qualname(base_name)
408                assert isinstance(qualname, str)
409                self.tensor_attrs[a] = qualname
410                setattr(self.root, qualname, a)
411
412            return self.create_node("get_attr", qualname, (), {})
413
414        if type(a) in _proxyable_classes:
415            # This is an instance of a proxyable class for which we did not
416            # witness its construction. Intern this as a constant attribute
417
418            # TODO: binary search
419            qualname = self.get_fresh_qualname(f"_{a.__class__.__name__}_constant_")
420            assert isinstance(qualname, str)
421            setattr(self.root, qualname, a)
422
423            return self.create_node("get_attr", qualname, (), {})
424
425        return super().create_arg(a)
426
427    @compatibility(is_backward_compatible=True)
428    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
429        """
430        A method to specify whether a given ``nn.Module`` is a "leaf" module.
431
432        Leaf modules are the atomic units that appear in
433        the IR, referenced by ``call_module`` calls. By default,
434        Modules in the PyTorch standard library namespace (torch.nn)
435        are leaf modules. All other modules are traced through and
436        their constituent ops are recorded, unless specified otherwise
437        via this parameter.
438
439        Args:
440
441            m (Module): The module being queried about
442            module_qualified_name (str): The path to root of this module. For example,
443                if you have a module hierarchy where submodule ``foo`` contains
444                submodule ``bar``, which contains submodule ``baz``, that module will
445                appear with the qualified name ``foo.bar.baz`` here.
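
        For example, to also treat a custom module class as a leaf (a sketch;
        ``MyCustomModule`` is a hypothetical class)::

            class MyTracer(Tracer):
                def is_leaf_module(self, m, module_qualified_name):
                    return isinstance(m, MyCustomModule) or super().is_leaf_module(
                        m, module_qualified_name
                    )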
446        """
447        return (
448            (m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn"))
449            and not isinstance(m, torch.nn.Sequential)
450        )
451
452    @compatibility(is_backward_compatible=True)
453    def path_of_module(self, mod: torch.nn.Module) -> str:
454        """
455        Helper method to find the qualified name of ``mod`` in the Module hierarchy
456        of ``root``. For example, if ``root`` has a submodule named ``foo``, which has
457        a submodule named ``bar``, passing ``bar`` into this function will return
458        the string "foo.bar".
459
460        Args:
461
462            mod (str): The ``Module`` to retrieve the qualified name for.
463        """
464        # Prefer the O(1) algorithm
465        if self.submodule_paths:
466            path = self.submodule_paths.get(mod)
467            if path is None:
468                raise NameError("module is not installed as a submodule")
469            assert isinstance(path, str)
470            return path
471        # O(N^2) fallback in the case that we didn't store the submodule
472        # paths.
473        else:
474            for n, p in self.root.named_modules():
475                if mod is p:
476                    return n
477            raise NameError("module is not installed as a submodule")
478
479    @compatibility(is_backward_compatible=True)
480    def call_module(
481        self,
482        m: torch.nn.Module,
483        forward: Callable[..., Any],
484        args: Tuple[Any, ...],
485        kwargs: Dict[str, Any],
486    ) -> Any:
487        """
488        Method that specifies the behavior of this ``Tracer`` when it encounters
489        a call to an ``nn.Module`` instance.
490
491        By default, the behavior is to check if the called module is a leaf module
492        via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to
493        ``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through
494        the operations in its ``forward`` function.
495
496        This method can be overridden to--for example--create nested traced
497        GraphModules, or any other behavior you would want while tracing across
498        ``Module`` boundaries.
499
500        Args:
501
502            m (Module): The module for which a call is being emitted
503            forward (Callable): The forward() method of the ``Module`` to be invoked
504            args (Tuple): args of the module callsite
505            kwargs (Dict): kwargs of the module callsite
506
507        Return:
508
509            The return value from the Module call. In the case that a ``call_module``
510            node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever
511            value was returned from the ``Module`` invocation.
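
        For example, a sketch of overriding this to emit a ``call_module`` node
        for every module, whether or not it is a leaf::

            class ModuleLevelTracer(Tracer):
                def call_module(self, m, forward, args, kwargs):
                    module_qualified_name = self.path_of_module(m)
                    return self.create_proxy(
                        "call_module", module_qualified_name, args, kwargs
                    )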
512        """
513        module_qualified_name = self.path_of_module(m)
514        with ScopeContextManager(self.scope, Scope(module_qualified_name, type(m))) as _scope:
515            # module_stack is an ordered dict so writing then deleting the
516            # entry is equivalent to push/pop on a list
517            self.module_stack[_scope.module_path] = (module_qualified_name, _scope.module_type)
518            if not self.is_leaf_module(m, module_qualified_name):
519                ret_val = forward(*args, **kwargs)
520            else:
521                ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs)
522            key, _ = self.module_stack.popitem(last=True)
523            assert key == _scope.module_path, f" Unexpected key {key}"
524
525        return ret_val
526
527    @compatibility(is_backward_compatible=False)
528    def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
529        """
530        Method that specifies the behavior of this ``Tracer`` when we call getattr
531        on a call to an ``nn.Module`` instance.
532
533        By default, the behavior is to return a proxy value for the attribute. It
534        also stores the proxy value in the ``parameter_proxy_cache``, so that future
535        calls will reuse the proxy rather than creating a new one.
536
537        This method can be overridden to --for example-- not return proxies when
538        querying parameters.
539
540        Args:
541
542            attr (str): The name of the attribute being queried
543            attr_val (Any): The value of the attribute
544            parameter_proxy_cache (Dict[str, Any]): A cache of attr names to proxies
545
546        Return:
547
548            The return value from the getattr call.
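
        For example, a sketch of overriding this so that parameter accesses
        return the parameter itself rather than a proxy::

            class NoParamProxyTracer(Tracer):
                def getattr(self, attr, attr_val, parameter_proxy_cache):
                    if isinstance(attr_val, torch.nn.Parameter):
                        return attr_val
                    return super().getattr(attr, attr_val, parameter_proxy_cache)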
549        """
550        def maybe_get_proxy_for_attr(
551            attr_val, collection_to_search, parameter_proxy_cache
552        ):
553            for n, p in collection_to_search:
554                if attr_val is p:
555                    if n not in parameter_proxy_cache:
556                        kwargs = {}
557                        if (
558                            "proxy_factory_fn"
559                            in inspect.signature(self.create_proxy).parameters
560                        ):
561                            kwargs["proxy_factory_fn"] = (
562                                None
563                                if not self.param_shapes_constant
564                                else lambda node: ParameterProxy(
565                                    self, node, n, attr_val
566                                )
567                            )
568                        val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs)  # type: ignore[arg-type]
569                        parameter_proxy_cache[n] = val_proxy
570                    return parameter_proxy_cache[n]
571            return None
572
573        if isinstance(attr_val, torch.nn.Parameter):
574            maybe_parameter_proxy = maybe_get_proxy_for_attr(
575                attr_val, self.root.named_parameters(), parameter_proxy_cache
576            )
577            if maybe_parameter_proxy is not None:
578                return maybe_parameter_proxy
579
580        if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
581            maybe_buffer_proxy = maybe_get_proxy_for_attr(
582                attr_val, self.root.named_buffers(), parameter_proxy_cache
583            )
584            if maybe_buffer_proxy is not None:
585                return maybe_buffer_proxy
586
587        return attr_val
588
589    # This method will be refactored
590    @compatibility(is_backward_compatible=False)
591    def create_args_for_root(self, root_fn, is_module, concrete_args=None):
592        """
593        Create ``placeholder`` nodes corresponding to the signature of the ``root``
594        Module. This method introspects root's signature and emits those
595        nodes accordingly, also supporting ``*args`` and ``**kwargs``.
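
        For example (a sketch; ``tracer`` and ``fn`` stand in for real values)::

            # for a module whose forward is `def forward(self, x)`:
            fn, args = tracer.create_args_for_root(fn, is_module=True)
            # emits one `placeholder` node named `x`;
            # args == [tracer.root, <Proxy for x>]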
596        """
597        # In some cases, a function or method has been decorated with a wrapper
598        # defined via ``functools.wraps``. In this case, the outer code object
599        # will likely not contain the actual parameters we care about, so unwrap
600        # the function to get to the innermost callable.
601        fn_for_analysis = inspect.unwrap(root_fn)
602        co = fn_for_analysis.__code__
603        total_args = co.co_argcount + co.co_kwonlyargcount
604        orig_args = list(co.co_varnames)
605        names_iter = iter(co.co_varnames)
606        args: List[Any] = []
607        skip_arg_idx = 0
608        if is_module:
609            if total_args == 0:
610                raise RuntimeError(
611                    "``self`` argument cannot be part of *args expansion!"
612                )
613            skip_arg_idx = 1
614            next(names_iter)  # skip self
615            args.append(self.root)
616
617        sig = inspect.signature(fn_for_analysis)
618
619
620        # This covers the very specific case where we are passing in flat
621        # concrete_args as a tuple, but our traced fn takes (*args, **kwargs).
622        # In this case, just take the concrete_args and pass them through.
623        name_idx = 0
624        if isinstance(concrete_args, tuple) and \
625                len(concrete_args) > 0 and \
626                (co.co_flags & HAS_VARSTUFF) and \
627                total_args == 1:
628            for concrete_arg in concrete_args:
629                out = self.create_proxy("placeholder", f"input_{name_idx}", (), {})
630                if isinstance(concrete_arg, PHBase):
631                    if concrete_arg != PH:
632                        # Transfer attrs in the case where you're using a placeholder other
633                        # than the singleton PH (PH has no attributes to transfer).
634                        # Proxies were created out of the placeholders.
635                        # Transfer any metadata (put on the placeholders in the form of
636                        # attributes set by the user) from the placeholder to the
637                        # underlying nodes (the proxy is unwrapped by the user, but
638                        # the metadata should hold).
639                        _transfer_attrs(fr=concrete_arg, to=out.node)
640                args.append(out)
641                name_idx += 1
642            return root_fn, args
643
644        arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)]
645        if isinstance(concrete_args, tuple):
646            if len(arg_names) != len(concrete_args):
647                raise RuntimeError(
648                    f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments"
649                )
650            concrete_args = dict(zip(arg_names, concrete_args))
651
652        def proxy_placeholder(name):
653            return self._proxy_placeholder(name, concrete_args, sig, fn_for_analysis)
654
655        args.extend(proxy_placeholder(names) for names in arg_names)
656
657        if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF:
658            # TODO: type annotations for *args and **kwargs
659            if co.co_flags & inspect.CO_VARARGS:
660                args.append(proxy_placeholder("*" + next(names_iter)))
661            if co.co_flags & inspect.CO_VARKEYWORDS:
662                args.append(proxy_placeholder("**" + next(names_iter)))
663            root_fn = _patch_function(root_fn, len(args))
664
665        flat_args, in_spec = pytree.tree_flatten(tuple(args))
666        if not all(child.is_leaf() for child in in_spec.children_specs):
667            # In the case that we have pytree-flattened inputs in
668            # `concrete_args`, generate a flattening wrapper around the
669            # original root function and return that.
670            self.graph._codegen = _PyTreeCodeGen(
671                _PyTreeInfo(orig_args[:total_args], in_spec, None)
672            )
673
674            def flatten_fn(*args):
675                tree_args = pytree.tree_unflatten(list(args), in_spec)
676                tree_out = root_fn(*tree_args)
677                out_args, out_spec = pytree.tree_flatten(tree_out)
678                assert isinstance(self.graph._codegen, _PyTreeCodeGen)
679                self.graph._codegen.pytree_info = (
680                    self.graph._codegen.pytree_info._replace(out_spec=out_spec)
681                )
682                return out_args
683
684            return flatten_fn, flat_args
685        return root_fn, args
686
687    @compatibility(is_backward_compatible=True)
688    def trace(
689        self,
690        root: Union[torch.nn.Module, Callable[..., Any]],
691        concrete_args: Optional[Dict[str, Any]] = None,
692    ) -> Graph:
693        """
694        Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root``
695        can either be an ``nn.Module`` instance or a Python callable.
696
697        Note that after this call, ``self.root`` may be different from the ``root`` passed
698        in here. For example, when a free function is passed to ``trace()``, we will
699        create an ``nn.Module`` instance to use as the root and add embedded constants
700        to.
701
702
703        Args:
704
705            root (Union[Module, Callable]): Either a ``Module`` or a function to be
706                traced through. Backwards-compatibility for this parameter is
707                guaranteed.
708            concrete_args (Optional[Dict[str, any]]): Concrete arguments that should
709                not be treated as Proxies. This parameter is experimental and
710                its backwards-compatibility is *NOT* guaranteed.
711
712        Returns:
713
714            A ``Graph`` representing the semantics of the passed-in ``root``.
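
        For example (a sketch; ``MyModule`` is a hypothetical module class)::

            tracer = Tracer()
            graph = tracer.trace(MyModule())
            gm = torch.fx.GraphModule(tracer.root, graph)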
715        """
716        global _is_fx_tracing_flag
717        old_is_fx_tracing_flag = _is_fx_tracing_flag
718        _is_fx_tracing_flag = True
719        try:
720            if isinstance(root, torch.nn.Module):
721
722                # do real recompilation for _LazyGraphModule before retracing since the trace
723                # method can not trace the _lazy_forward method. Got error:
724                #   https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259
725                # without this.
726                from torch.fx._lazy_graph_module import _LazyGraphModule
727                _LazyGraphModule.force_recompile(root)
728
729                self.root = root
730
731                assert hasattr(
732                    type(root), self.traced_func_name
733                ), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}"
734
735                fn = getattr(type(root), self.traced_func_name)
736                self.root_module_name = root._get_name()
737                self.submodule_paths = {mod: name for name, mod in root.named_modules()}
738            else:
739                self.root = torch.nn.Module()
740                fn = root
741
742            tracer_cls: Optional[Type[Tracer]] = getattr(self, "__class__", None)
743            self.graph = Graph(tracer_cls=tracer_cls)
744            if hasattr(fn, '__code__'):
745                code = fn.__code__
746                self.graph._co_fields = {
747                    'co_name': code.co_name,
748                    'co_filename': code.co_filename,
749                    'co_firstlineno': code.co_firstlineno,
750                }
751
752            # When we encounter a Tensor value that's not a parameter, we look if it
753            # is some other attribute on the model. Construct a dict mapping Tensor
754            # values to the qualified name here for efficiency. This is used downstream
755            # in create_arg
756            self.tensor_attrs: Dict[
757                Union[
758                    torch.Tensor,
759                    ScriptObject,
760                    FakeScriptObject
761                ], str
762            ] = {}
763
764            def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]):
765                for k, v in m.__dict__.items():
766                    if isinstance(v, (torch.Tensor, ScriptObject, FakeScriptObject)):
767                        self.tensor_attrs[v] = ".".join(prefix_atoms + [k])
768                for k, v in m.named_children():
769                    collect_tensor_attrs(v, prefix_atoms + [k])
770
771            collect_tensor_attrs(self.root, [])
772
773            assert isinstance(fn, FunctionType)
774
775            fn_globals = fn.__globals__  # run before it gets patched
776            fn, args = self.create_args_for_root(
777                fn, isinstance(root, torch.nn.Module), concrete_args
778            )
779
780            parameter_proxy_cache: Dict[
781                str, Proxy
782            ] = {}  # Reduce number of get_attr calls
783
784            # Method dispatch on parameters is not recorded unless it's directly used.
785            # Thus, we need to insert a proxy when __getattr__ requests a parameter.
786            @functools.wraps(_orig_module_getattr)
787            def module_getattr_wrapper(mod, attr):
788                attr_val = _orig_module_getattr(mod, attr)
789                return self.getattr(attr, attr_val, parameter_proxy_cache)
790
791            @functools.wraps(_orig_module_call)
792            def module_call_wrapper(mod, *args, **kwargs):
793                def forward(*args, **kwargs):
794                    return _orig_module_call(mod, *args, **kwargs)
795
796                _autowrap_check(
797                    patcher,  # type: ignore[has-type]
798                    getattr(getattr(mod, "forward", mod), "__globals__", {}),
799                    self._autowrap_function_ids,
800                )
801                return self.call_module(mod, forward, args, kwargs)
802
803            with _new_patcher() as patcher:
804                # allow duplicate patches to support the case of nested calls
805                patcher.patch_method(
806                    torch.nn.Module,
807                    "__getattr__",
808                    module_getattr_wrapper,
809                    deduplicate=False,
810                )
811                patcher.patch_method(
812                    torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False
813                )
814                _patch_wrapped_functions(patcher)
815                _autowrap_check(patcher, fn_globals, self._autowrap_function_ids)
816                for module in self._autowrap_search:
817                    _autowrap_check(
818                        patcher, module.__dict__, self._autowrap_function_ids
819                    )
820                self.create_node(
821                    "output",
822                    "output",
823                    (self.create_arg(fn(*args)),),
824                    {},
825                    type_expr=fn.__annotations__.get("return", None),
826                )
827
828            self.submodule_paths = None
829        finally:
830            _is_fx_tracing_flag = old_is_fx_tracing_flag
831        return self.graph
832
833    def __deepcopy__(self, memo):
834        # _autowrap_search contains modules, which cannot be deepcopied.
835        new_tracer = Tracer.__new__(Tracer)
836
837        for k, v in self.__dict__.items():
838            if k in {'_autowrap_search'}:
839                new_obj = copy.copy(v)
840            else:
841                new_obj = copy.deepcopy(v, memo)
842
843            new_tracer.__dict__[k] = new_obj
844
845        return new_tracer
846
847    def _proxy_placeholder(self, name, concrete_args, sig, fn_for_analysis):
848        if concrete_args is not None and name in concrete_args:
849            cnt = 0
850
851            def replace_ph(x):
852                nonlocal cnt
853                cnt += 1
854                param = sig.parameters[name]
855                default = (
856                    ()
857                    if param.default is inspect.Parameter.empty
858                    else (param.default,)
859                )
860                out = self.create_proxy(
861                    "placeholder", f"{name}_{str(cnt)}", default, {}
862                )
863                if isinstance(x, PHBase):
864                    if x != PH:
865                        # Transfer attrs in the case where you're using a placeholder other
866                        # than the singleton PH (PH has no attributes to transfer).
867                        # Proxies were created out of the placeholders.
868                        # Transfer any metadata (put on the placeholders in the form of
869                        # attributes set by the user) from the placeholder to the
870                        # underlying nodes (the proxy is unwrapped by the user, but
871                        # the metadata should hold).
872                        _transfer_attrs(fr=x, to=out.node)
873
874                    return out
875                # Union[int, bool] == bool in Python <= 3.6
876                if (
877                    type(x) == bool
878                    or type(x) in base_types
879                    and type(x) != torch.Tensor
880                ):
881                    torch._assert(
882                        out == x,
883                        f"{name} has been specialized to have value {x} but got another value",
884                    )
885                elif x is None:
886                    args = (
887                        out,
888                        f"{name} has been specialized to have value None but got another value",
889                    )
890                    self.create_proxy("call_function", _assert_is_none, args, {})
891                else:
892                    warnings.warn(
893                        f"Was not able to add assertion to guarantee correct input {name} to "
894                        f"specialized function. It is up to the user to make sure that your inputs match the "
895                        f"inputs you specialized the function with."
896                    )
897
898                return x
899
900            return pytree.tree_map(replace_ph, concrete_args[name])
901        if name[0] == "*":
902            default = ()
903        else:
904            param = sig.parameters[name]
905            default = () if param.default is inspect.Parameter.empty else (param.default,)  # type: ignore[assignment]
906        return self.create_proxy(
907            "placeholder",
908            name,
909            default,
910            {},
911            type_expr=fn_for_analysis.__annotations__.get(name, None)
912        )
913
914
# Dictionary of (id(globals dict), function name) => globals_dict to patch for
# the purposes of the wrap() API.
# We key by the globals dict id and function name to ensure we're wrapping a given
# function only once.
_wrapped_fns_to_patch: Dict[Tuple[int, str], dict] = {}

# List of methods on classes to wrap (class type, function name)
# this currently only works for Tensor.* methods that aren't traced properly
_wrapped_methods_to_patch: List[Tuple[type, str]] = []

if os.environ.get("FX_PATCH_GETITEM") == "1":
    # This change is needed to trace models like PositionalEmbedding from BERT:
    # https://github.com/pytorch/benchmark/blob/master/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py
    # but causes issues in quantization documented here:
    # https://github.com/pytorch/pytorch/issues/50710
    # once that is fixed we can make this the default behavior.
    _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__"))


def _find_proxy(*objects_to_search):
    """
    Recursively search a data structure for a Proxy() and return it,
    or return None if not found.
    """
    proxy = None

    def find_proxy(x):
        nonlocal proxy
        if isinstance(x, Proxy):
            proxy = x

    map_aggregate(objects_to_search, find_proxy)
    return proxy


def _create_wrapped_func(orig_fn):
    @functools.wraps(orig_fn)
    def wrapped(*args, **kwargs):
        """
        Given the closed-over ``orig_fn`` to invoke, search the args and kwargs for
        a Proxy object. If there is one, emit a ``call_function`` node to preserve the
        call to this leaf function directly. Otherwise, just return the results of
        this function call, as this function is not being traced.
        """
        proxy = _find_proxy(args, kwargs)
        if proxy is not None:
            return_proxy = proxy.tracer.create_proxy(
                "call_function", orig_fn, args, kwargs
            )
            return_proxy.node.meta["is_wrapped"] = True
            return return_proxy
        return orig_fn(*args, **kwargs)

    return wrapped


def _create_wrapped_method(cls, name):
    orig_fn = getattr(cls, name)

    @functools.wraps(orig_fn)
    def wrapped(*args, **kwargs):
        """
        Search the args and kwargs for a Proxy object. If there is one,
        emit a ``call_method`` node to preserve the call to this method
        directly. Otherwise, just return the results of this function
        call, as this function is not being traced.
        """
        proxy = _find_proxy(args, kwargs)
        if proxy is not None:
            return proxy.tracer.create_proxy("call_method", name, args, kwargs)
        return orig_fn(*args, **kwargs)

    return wrapped


class _PatchedFn(NamedTuple):
    frame_dict: Any
    fn_name: str
    orig_fn: Any
    new_fn: Any

    def revert(self):
        raise NotImplementedError

    def patch(self):
        raise NotImplementedError


class _PatchedFnSetItem(_PatchedFn):
    def revert(self):
        self.frame_dict[self.fn_name] = self.orig_fn

    def patch(self):
        self.frame_dict[self.fn_name] = self.new_fn


class _PatchedFnDel(_PatchedFn):
    def revert(self):
        del self.frame_dict[self.fn_name]

    def patch(self):
        self.frame_dict[self.fn_name] = self.new_fn


class _PatchedFnSetAttr(_PatchedFn):
    def revert(self):
        setattr(self.frame_dict, self.fn_name, self.orig_fn)

    def patch(self):
        setattr(self.frame_dict, self.fn_name, self.new_fn)


class _Patcher:
    def __init__(self) -> None:
        super().__init__()
        self.patches_made: List[_PatchedFn] = []
        self.visited: Set[int] = set()

    def patch(
        self,
        frame_dict: Dict[str, Any],
        name: str,
        new_fn: Callable,
        deduplicate: bool = True,
    ):
        """
        Replace ``frame_dict[name]`` with ``new_fn`` until we exit the context manager.
        """
        new_fn.__fx_already_patched = deduplicate  # type: ignore[attr-defined]
        if name not in frame_dict and hasattr(builtins, name):
            self.patches_made.append(_PatchedFnDel(frame_dict, name, None, new_fn))
            self.patches_made[-1].patch()
        elif getattr(frame_dict[name], "__fx_already_patched", False):
            return  # already patched, no need to do it again
        else:
            self.patches_made.append(
                _PatchedFnSetItem(frame_dict, name, frame_dict[name], new_fn)
            )
            self.patches_made[-1].patch()

    def patch_method(
        self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True
    ):
        """
        Replace ``cls.name`` with ``new_fn`` until we exit the context manager.
        """
        new_fn.__fx_already_patched = deduplicate  # type: ignore[attr-defined]
        orig_fn = getattr(cls, name)
        if getattr(orig_fn, "__fx_already_patched", False):
            return  # already patched, no need to do it again
        self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn, new_fn))
        self.patches_made[-1].patch()

    def visit_once(self, thing: Any):
        """Return True on the first call with a given ``thing``, otherwise False"""
        idx = id(thing)
        if idx in self.visited:
            return False
        self.visited.add(idx)
        return True

    def revert_all_patches(self):
        """
        Revert all the stored patches. This doesn't modify ``patches_made``.
        """
        for patch in self.patches_made:
            patch.revert()
        return self.patches_made

    def reapply_all_patches(self):
        """
        Reapply all the stored patches. This doesn't modify ``patches_made``.
        """
        for patch in self.patches_made:
            patch.patch()
        return self.patches_made

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Undo all the changes made via self.patch() and self.patch_method()
        """
        while self.patches_made:
            # unpatch in reverse order to handle duplicates correctly
            self.patches_made.pop().revert()
        self.visited.clear()
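

# Example usage of ``_Patcher`` (a sketch; this is an internal helper, and
# ``call_wrapper`` is a hypothetical replacement function):
#
#     with _Patcher() as patcher:
#         patcher.patch_method(torch.nn.Module, "__call__", call_wrapper)
#         ...  # code running here sees the patched __call__
#     # on exit, all patches are reverted in reverse order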


CURRENT_PATCHER: Optional[_Patcher] = None


@contextlib.contextmanager
def _new_patcher():
    global CURRENT_PATCHER
    prior_patcher = CURRENT_PATCHER
    try:
        CURRENT_PATCHER = _Patcher()
        yield CURRENT_PATCHER
    finally:
        # Revert all the patches made while the current patcher was active.
        assert CURRENT_PATCHER is not None
        CURRENT_PATCHER.revert_all_patches()
        CURRENT_PATCHER = prior_patcher


@contextlib.contextmanager
def _maybe_revert_all_patches():
    current_patcher = CURRENT_PATCHER
    patches_made = None
    patches_removed = None
    try:
        if current_patcher is not None:
            patches_removed = current_patcher.revert_all_patches()
        yield
    finally:
        if current_patcher is not None:
            patches_made = current_patcher.reapply_all_patches()
        assert patches_made == patches_removed, "CURRENT_PATCHER was changed during a revert_all_patches"


def _patch_wrapped_functions(patcher: _Patcher):
    """
    Go through ``_wrapped_fns_to_patch`` and, for each frame object, wrap
    the listed global functions in the `_create_wrapped_func` wrapper.
    """
    for (_, name), frame_dict in _wrapped_fns_to_patch.copy().items():
        if name not in frame_dict and hasattr(builtins, name):
            orig_fn = getattr(builtins, name)
        else:
            orig_fn = frame_dict[name]
        patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn))

    for cls, name in _wrapped_methods_to_patch:
        patcher.patch_method(cls, name, _create_wrapped_method(cls, name))


def _autowrap_check(
    patcher: _Patcher, frame_dict: Dict[str, Any], function_ids: Set[int]
):
    """
    Some functions, like `math.sqrt`, are common enough that we want to
    automatically wrap them as we see them. This method searches a scope
    for them and patches them if found.
    """
    if patcher.visit_once(frame_dict):
        for name, value in frame_dict.items():
            if (
                not name.startswith("_")
                and callable(value)
                and id(value) in function_ids
            ):
                patcher.patch(frame_dict, name, _create_wrapped_func(value))


@compatibility(is_backward_compatible=True)
def wrap(fn_or_name: Union[str, Callable]):
    """
    This function can be called at module-level scope to register fn_or_name as a "leaf function".
    A "leaf function" will be preserved as a CallFunction node in the FX trace instead of being
    traced through::

        # foo/bar/baz.py
        def my_custom_function(x, y):
            return x * x + y * y

        torch.fx.wrap('my_custom_function')

        def fn_to_be_traced(x, y):
            # When symbolic tracing, the below call to my_custom_function will be inserted into
            # the graph rather than tracing it.
            return my_custom_function(x, y)

    This function can also equivalently be used as a decorator::

        # foo/bar/baz.py
        @torch.fx.wrap
        def my_custom_function(x, y):
            return x * x + y * y

    A wrapped function can be thought of as a "leaf function", analogous to the concept of
    "leaf modules": functions that are left as calls in the FX trace
    rather than traced through.

    Args:

        fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the
            graph when it's called
    """
    if not callable(fn_or_name) and not isinstance(fn_or_name, str):
        raise RuntimeError(
            "Unsupported type for global function! Must be either a callable or "
            "string name"
        )

    if callable(fn_or_name):
        assert not isinstance(fn_or_name, str)  # to make mypy happy
        fn_name = fn_or_name.__name__
    else:
        assert isinstance(
            fn_or_name, str
        ), "fn_or_name must be a global function or string name"
        fn_name = fn_or_name

    currentframe = inspect.currentframe()
    assert currentframe is not None
    f = currentframe.f_back
    assert f is not None
    if f.f_code.co_name != "<module>":
        raise NotImplementedError("wrap must be called at the top level of a module")

    # consider implementing a Callable version of this via _autowrap_function_ids / _autowrap_search
    # semantics would be slightly different, but would add support for `from x import wrapped_function`
    _wrapped_fns_to_patch[(id(f.f_globals), fn_name)] = f.f_globals
    return fn_or_name


@compatibility(is_backward_compatible=True)
def symbolic_trace(
    root: Union[torch.nn.Module, Callable[..., Any]],
    concrete_args: Optional[Dict[str, Any]] = None,
) -> GraphModule:
    """
    Symbolic tracing API

    Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule``
    constructed by recording operations seen while tracing through ``root``.

    ``concrete_args`` allows you to partially specialize your function, whether that is
    to remove control flow or data structures.

    For example::

        def f(a, b):
            if b == True:
                return a
            else:
                return a*2

    FX typically cannot trace through this due to the presence of control
    flow. However, we can use `concrete_args` to specialize on the value of
    `b` to trace through this::

        f = fx.symbolic_trace(f, concrete_args={'b': False})
        assert f(3, False) == 6

    Note that although you can still pass in different values of `b`, they will be ignored.

    We can also use `concrete_args` to eliminate data-structure handling from
    our function. This will use pytrees to flatten your input. To avoid
    overspecializing, pass in `fx.PH` for values that shouldn't be
    specialized. For example::

        def f(x):
            out = 0
            for v in x.values():
                out += v
            return out
        f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}})
        assert f({'a': 1, 'b': 2, 'c': 4}) == 7

    Args:
        root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted
            into a Graph representation.
        concrete_args (Optional[Dict[str, any]]): Inputs to be partially specialized

    Returns:
        GraphModule: a Module created from the recorded operations from ``root``.
    """
    tracer = Tracer()
    graph = tracer.trace(root, concrete_args)
    name = (
        root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    )
    return _make_graph_module(tracer.root, graph, name)


@wrap
def _assert_is_none(value, msg):
    assert value is None, msg