xref: /aosp_15_r20/external/pytorch/torch/fx/node.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Nodes represent a definition of a value in our graph of operators.
2*da0073e9SAndroid Build Coastguard Workerfrom typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict, Set
3*da0073e9SAndroid Build Coastguard Workerfrom ._compatibility import compatibility
4*da0073e9SAndroid Build Coastguard Workerfrom .immutable_collections import immutable_dict, immutable_list
5*da0073e9SAndroid Build Coastguard Workerimport torch
6*da0073e9SAndroid Build Coastguard Workerimport builtins
7*da0073e9SAndroid Build Coastguard Workerimport types
8*da0073e9SAndroid Build Coastguard Workerimport inspect
9*da0073e9SAndroid Build Coastguard Workerimport warnings
10*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair
11*da0073e9SAndroid Build Coastguard Workerfrom .._ops import ops as _ops
12*da0073e9SAndroid Build Coastguard Workerfrom torch._C import _NodeBase
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Workerif TYPE_CHECKING:
15*da0073e9SAndroid Build Coastguard Worker    from .graph import Graph
16*da0073e9SAndroid Build Coastguard Worker
# Public names exported by this module.
__all__ = ['Node', 'map_arg', 'map_aggregate', "has_side_effect"]

# Leaf (non-container) value types that may appear as an argument of a Node.
BaseArgumentTypes = Union[str, int, float, bool, complex, torch.dtype,
                          torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload,
                          torch.SymInt, torch.SymBool, torch.SymFloat]
# The same leaf types as a plain tuple of types (the Union's ``__args__``),
# which is the form ``isinstance`` accepts.
base_types = BaseArgumentTypes.__args__  # type: ignore[attr-defined]

# What a Node refers to: a callable for ``call_function`` nodes and a string
# (method name, module/attribute path, or placeholder name) for every other
# opcode -- see the target validation in ``Node.__init__``.
Target = Union[Callable[..., Any], str]

# Anything that may appear inside ``Node.args`` / ``Node.kwargs``: ``None``,
# containers of further Arguments, slices, ranges, other Nodes, or a leaf value
# from ``BaseArgumentTypes``.
Argument = Optional[Union[
    Tuple[Any, ...],  # actually Argument, but mypy can't represent recursive types
    List[Any],  # actually Argument
    Dict[str, Any],  # actually Argument
    slice,  # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
    range,
    'Node',
    BaseArgumentTypes
]]
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard Worker_legal_ops = dict.fromkeys(['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root'])
37*da0073e9SAndroid Build Coastguard Worker
38*da0073e9SAndroid Build Coastguard Worker_side_effectful_need_to_be_preserved_pre_dispatch: Set[Callable] = {
39*da0073e9SAndroid Build Coastguard Worker    torch._C._set_grad_enabled,
40*da0073e9SAndroid Build Coastguard Worker    torch.amp._enter_autocast,
41*da0073e9SAndroid Build Coastguard Worker    torch.amp._exit_autocast,
42*da0073e9SAndroid Build Coastguard Worker}
43*da0073e9SAndroid Build Coastguard Worker
44*da0073e9SAndroid Build Coastguard Worker# TODO: Either refactor this into 2 functions 1 dce for functional graphs and 1 dce for all graphs,
45*da0073e9SAndroid Build Coastguard Worker# or add logic to correctly mark all inplace ops as side effectful.
46*da0073e9SAndroid Build Coastguard Worker_side_effectful_functions: Set[Callable] = {
47*da0073e9SAndroid Build Coastguard Worker    torch._assert,
48*da0073e9SAndroid Build Coastguard Worker    torch._assert_async,
49*da0073e9SAndroid Build Coastguard Worker    _ops.aten._assert_async.msg,
50*da0073e9SAndroid Build Coastguard Worker    _ops.aten._assert_scalar.default,
51*da0073e9SAndroid Build Coastguard Worker    _ops.aten.sym_constrain_range.default,
52*da0073e9SAndroid Build Coastguard Worker    _ops.aten.sym_constrain_range_for_size.default,
53*da0073e9SAndroid Build Coastguard Worker    _ops.profiler._record_function_enter,
54*da0073e9SAndroid Build Coastguard Worker    _ops.profiler._record_function_enter_new,
55*da0073e9SAndroid Build Coastguard Worker    _ops.profiler._record_function_exit,
56*da0073e9SAndroid Build Coastguard Worker    _ops.inductor.accumulate_grad_.default,
57*da0073e9SAndroid Build Coastguard Worker} | _side_effectful_need_to_be_preserved_pre_dispatch
58*da0073e9SAndroid Build Coastguard Workerif hasattr(_ops.inductor, "resize_storage_bytes_"):
59*da0073e9SAndroid Build Coastguard Worker    _side_effectful_functions.add(_ops.inductor.resize_storage_bytes_.default)
60*da0073e9SAndroid Build Coastguard Worker
61*da0073e9SAndroid Build Coastguard Worker
@compatibility(is_backward_compatible=False)
def has_side_effect(fn: Callable) -> Callable:
    """Decorator that registers ``fn`` as a side-effectful callable.

    ``fn`` is added to the module-level ``_side_effectful_functions`` registry
    and handed back unchanged, so this can be used as ``@has_side_effect``.
    """
    registry = _side_effectful_functions
    registry.add(fn)
    return fn
66*da0073e9SAndroid Build Coastguard Worker
67*da0073e9SAndroid Build Coastguard Worker
68*da0073e9SAndroid Build Coastguard Worker# this is fixed on master, WAR for 1.5
69*da0073e9SAndroid Build Coastguard Workerdef _find_module_of_method(orig_method: Callable[..., Any]) -> str:
70*da0073e9SAndroid Build Coastguard Worker    name = orig_method.__name__
71*da0073e9SAndroid Build Coastguard Worker    module = orig_method.__module__
72*da0073e9SAndroid Build Coastguard Worker    if module is not None:
73*da0073e9SAndroid Build Coastguard Worker        return module
74*da0073e9SAndroid Build Coastguard Worker    for guess in [torch, torch.nn.functional]:
75*da0073e9SAndroid Build Coastguard Worker        if getattr(guess, name, None) is orig_method:
76*da0073e9SAndroid Build Coastguard Worker            return guess.__name__
77*da0073e9SAndroid Build Coastguard Worker    raise RuntimeError(f'cannot find module for {orig_method}')
78*da0073e9SAndroid Build Coastguard Worker
79*da0073e9SAndroid Build Coastguard Worker# Borrowed from CPython typing module
80*da0073e9SAndroid Build Coastguard Worker# https://github.com/python/cpython/blob/f90dc36c15d7fee0efaf6d39e97be0bdf2683e93/Lib/typing.py#L156
81*da0073e9SAndroid Build Coastguard Workerdef _type_repr(obj: object) -> str:
82*da0073e9SAndroid Build Coastguard Worker    """Return the repr() of an object, special-casing types (internal helper).
83*da0073e9SAndroid Build Coastguard Worker    If obj is a type, we return a shorter version than the default
84*da0073e9SAndroid Build Coastguard Worker    type.__repr__, based on the module and qualified name, which is
85*da0073e9SAndroid Build Coastguard Worker    typically enough to uniquely identify a type.  For everything
86*da0073e9SAndroid Build Coastguard Worker    else, we fall back on repr(obj).
87*da0073e9SAndroid Build Coastguard Worker    """
88*da0073e9SAndroid Build Coastguard Worker    if isinstance(obj, type):
89*da0073e9SAndroid Build Coastguard Worker        if obj.__module__ == 'builtins':
90*da0073e9SAndroid Build Coastguard Worker            return obj.__qualname__
91*da0073e9SAndroid Build Coastguard Worker        return f'{obj.__module__}.{obj.__qualname__}'
92*da0073e9SAndroid Build Coastguard Worker    if obj is ...:
93*da0073e9SAndroid Build Coastguard Worker        return '...'
94*da0073e9SAndroid Build Coastguard Worker    if isinstance(obj, types.FunctionType):
95*da0073e9SAndroid Build Coastguard Worker        return obj.__name__
96*da0073e9SAndroid Build Coastguard Worker    return repr(obj)
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Workerdef _get_qualified_name(func: Callable[..., Any]) -> str:
99*da0073e9SAndroid Build Coastguard Worker    # things like getattr just appear in builtins
100*da0073e9SAndroid Build Coastguard Worker    if getattr(builtins, func.__name__, None) is func:
101*da0073e9SAndroid Build Coastguard Worker        return func.__name__
102*da0073e9SAndroid Build Coastguard Worker    # torch.Tensor.{fn}
103*da0073e9SAndroid Build Coastguard Worker    if (isinstance(func, (types.MethodDescriptorType, types.WrapperDescriptorType))
104*da0073e9SAndroid Build Coastguard Worker       and func is getattr(torch.Tensor, func.__name__, None)):
105*da0073e9SAndroid Build Coastguard Worker        return f"torch.Tensor.{func.__name__}"
106*da0073e9SAndroid Build Coastguard Worker    name = func.__name__
107*da0073e9SAndroid Build Coastguard Worker    if name == "<lambda>":
108*da0073e9SAndroid Build Coastguard Worker        # For lambdas, try to get their defining name in the module
109*da0073e9SAndroid Build Coastguard Worker        try:
110*da0073e9SAndroid Build Coastguard Worker            name = inspect.getsource(func).split("=")[0].strip()
111*da0073e9SAndroid Build Coastguard Worker        except Exception as e:
112*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError("Unable to represent lambda") from e
113*da0073e9SAndroid Build Coastguard Worker    module = _find_module_of_method(func)
114*da0073e9SAndroid Build Coastguard Worker    module = module.replace('torch._ops', 'torch.ops')  # WAR for bug in how torch.ops assigns module
115*da0073e9SAndroid Build Coastguard Worker    # Fixup segment_reduce mismatch
116*da0073e9SAndroid Build Coastguard Worker    if module == "torch" and name == "segment_reduce":
117*da0073e9SAndroid Build Coastguard Worker        name = "_" + name
118*da0073e9SAndroid Build Coastguard Worker    return f'{module}.{name}'
119*da0073e9SAndroid Build Coastguard Worker
120*da0073e9SAndroid Build Coastguard Workerdef _format_arg(arg: object, max_list_len: float = float('inf')) -> str:
121*da0073e9SAndroid Build Coastguard Worker    if hasattr(arg, '_custom_fx_repr_fn'):
122*da0073e9SAndroid Build Coastguard Worker        return arg._custom_fx_repr_fn()
123*da0073e9SAndroid Build Coastguard Worker    elif isinstance(arg, list):
124*da0073e9SAndroid Build Coastguard Worker        items = ', '.join(_format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len)
125*da0073e9SAndroid Build Coastguard Worker        maybe_len = '' if len(arg) < max_list_len + 1 else f', ...[total_len={len(arg)}]'
126*da0073e9SAndroid Build Coastguard Worker        return f'[{items}{maybe_len}]'
127*da0073e9SAndroid Build Coastguard Worker    elif isinstance(arg, tuple):
128*da0073e9SAndroid Build Coastguard Worker        items = ', '.join(_format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len)
129*da0073e9SAndroid Build Coastguard Worker        maybe_len = '' if len(arg) < max_list_len + 1 else f', ...[total_len={len(arg)}]'
130*da0073e9SAndroid Build Coastguard Worker        maybe_comma = ',' if len(arg) == 1 else ''
131*da0073e9SAndroid Build Coastguard Worker        return f'({items}{maybe_comma}{maybe_len})'
132*da0073e9SAndroid Build Coastguard Worker    elif isinstance(arg, dict):
133*da0073e9SAndroid Build Coastguard Worker        items_str = ', '.join(f'{k}: {_format_arg(v)}' for k, v in arg.items())
134*da0073e9SAndroid Build Coastguard Worker        return f'{{{items_str}}}'
135*da0073e9SAndroid Build Coastguard Worker
136*da0073e9SAndroid Build Coastguard Worker    if isinstance(arg, Node):
137*da0073e9SAndroid Build Coastguard Worker        return '%' + str(arg)
138*da0073e9SAndroid Build Coastguard Worker    else:
139*da0073e9SAndroid Build Coastguard Worker        return str(arg)
140*da0073e9SAndroid Build Coastguard Worker
141*da0073e9SAndroid Build Coastguard Worker@compatibility(is_backward_compatible=True)
142*da0073e9SAndroid Build Coastguard Workerclass Node(_NodeBase):
143*da0073e9SAndroid Build Coastguard Worker    """
144*da0073e9SAndroid Build Coastguard Worker    ``Node`` is the data structure that represents individual operations within
145*da0073e9SAndroid Build Coastguard Worker    a ``Graph``. For the most part, Nodes represent callsites to various entities,
146*da0073e9SAndroid Build Coastguard Worker    such as operators, methods, and Modules (some exceptions include nodes that
147*da0073e9SAndroid Build Coastguard Worker    specify function inputs and outputs). Each ``Node`` has a function specified
148*da0073e9SAndroid Build Coastguard Worker    by its ``op`` property. The ``Node`` semantics for each value of ``op`` are as follows:
149*da0073e9SAndroid Build Coastguard Worker
150*da0073e9SAndroid Build Coastguard Worker    - ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on.
151*da0073e9SAndroid Build Coastguard Worker      ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument
152*da0073e9SAndroid Build Coastguard Worker      denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to
153*da0073e9SAndroid Build Coastguard Worker      the function parameters (e.g. ``x``) in the graph printout.
154*da0073e9SAndroid Build Coastguard Worker    - ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the
155*da0073e9SAndroid Build Coastguard Worker      fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy.
156*da0073e9SAndroid Build Coastguard Worker      ``args`` and ``kwargs`` are don't-care
157*da0073e9SAndroid Build Coastguard Worker    - ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign
158*da0073e9SAndroid Build Coastguard Worker      to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function,
159*da0073e9SAndroid Build Coastguard Worker      following the Python calling convention
160*da0073e9SAndroid Build Coastguard Worker    - ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is
161*da0073e9SAndroid Build Coastguard Worker      as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call.
    - ``call_method`` calls a method on a value. ``name`` is as previous. ``target`` is the string name of the method
      to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the method on,
      *including the self argument*
165*da0073e9SAndroid Build Coastguard Worker      *including the self argument*
166*da0073e9SAndroid Build Coastguard Worker    - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement
167*da0073e9SAndroid Build Coastguard Worker      in the Graph printout.
168*da0073e9SAndroid Build Coastguard Worker    """
169*da0073e9SAndroid Build Coastguard Worker    _args: Tuple['Argument', ...]
170*da0073e9SAndroid Build Coastguard Worker    _kwargs: Dict[str, 'Argument']
171*da0073e9SAndroid Build Coastguard Worker
    @compatibility(is_backward_compatible=True)
    def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target',
                 args: Tuple['Argument', ...], kwargs: Dict[str, 'Argument'],
                 return_type : Optional[Any] = None) -> None:
        """
        Instantiate an instance of ``Node``. Note: most often, you want to use the
        Graph APIs, i.e. ``Graph.call_module``, ``Graph.call_method``, etc. rather
        than instantiating a ``Node`` directly.

        Args:
            graph (Graph): The ``Graph`` to which this ``Node`` should belong.

            name (str): The name to which the output of this ``Node`` should be assigned

            op (str): The opcode for this ``Node``. Can be one of 'placeholder',
                'call_method', 'call_module', 'call_function', 'get_attr',
                'output' or 'root'.

            target ('Target'): The target this op should call. Must be callable
                when ``op == 'call_function'`` and a ``str`` for every other
                opcode. See the broader ``Node`` docstring for more details.

            args (Tuple['Argument']): The args to be passed to ``target``

            kwargs (Dict[str, 'Argument']): The kwargs to be passed to ``target``

            return_type (Optional[Any]): The python type expression representing the
                type of the output of this node. This field can be used for
                annotation of values in the generated code or for other types
                of analyses.

        Raises:
            AssertionError: if ``op`` is not one of the legal opcodes.
            ValueError: if ``target`` has the wrong type for ``op``.
        """
        super().__init__()
        self.graph = graph
        self.name = name  # unique name of value being created
        # Reject unknown opcodes. NOTE: this is a bare ``assert``, so the check
        # is skipped when Python runs with ``-O``.
        assert op in _legal_ops
        self.op = op  # the kind of operation = placeholder|call_method|call_module|call_function|get_attr|output|root
        # ``call_function`` needs a callable target; every other opcode names
        # its target with a string.
        if op == 'call_function':
            if not callable(target):
                raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} '
                                 'but a Callable is expected')
        else:
            if not isinstance(target, str):
                raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} '
                                 'but a str is expected')
        self.target = target  # for method/module/function, the name of the method/module/function/attr
        # being invoked, e.g add, layer1, or torch.add

        # All `Node`-valued inputs. Key is the Node, value is don't-care.
        # The public API for this is `all_input_nodes`, this private attribute
        # should not be accessed directly.
        # It is created before `__update_args_kwargs` runs; presumably that
        # method (defined outside this view) fills it from args/kwargs.
        self._input_nodes : Dict[Node, None] = {}
        self.__update_args_kwargs(args, kwargs)

        # All of the nodes that use the value produced by this Node
        # Note one user may correspond to several uses, e.g. the node for ``x + x``
        # would appear once here, but represents two uses.
        #
        # Is a dict to act as an "ordered set". Keys are significant, value dont-care
        self.users : Dict[Node, None] = {}
        # Type expression representing the output value of this node.
        # This should contain the same class of Type objects that would appear
        # as type annotations for function inputs/outputs.
        #
        # For placeholder nodes, this value will be used to type-annotate the
        # generated function parameters.
        # For the return node, this value will be used to type-annotate the
        # generated function return type. (Note this is a special case. ``return``
        # does not produce a value, it's more of a notation. Thus, this value
        # describes the type of args[0] in the ``return`` node.)
        self.type : Optional[Any] = return_type
        # Ordering key of this node in the graph's node list. `prepend` recomputes
        # it on insertion, and `<` / `>` between nodes compare these keys.
        self._sort_key: Any = ()

        # If set, use this fn to print this node
        self._repr_fn : Optional[Callable[[Node], str]] = None

        # Dictionary to store metadata passes need to do their
        # transformations. This metadata is preserved across node copies
        self.meta : Dict[str, Any] = {}
249*da0073e9SAndroid Build Coastguard Worker
250*da0073e9SAndroid Build Coastguard Worker    def __getstate__(self) -> Dict[str, Any]:
251*da0073e9SAndroid Build Coastguard Worker        state = self.__dict__.copy()
252*da0073e9SAndroid Build Coastguard Worker        state["_erased"] = self._erased
253*da0073e9SAndroid Build Coastguard Worker        state["_prev"] = self._prev
254*da0073e9SAndroid Build Coastguard Worker        state["_next"] = self._next
255*da0073e9SAndroid Build Coastguard Worker        return state
256*da0073e9SAndroid Build Coastguard Worker
257*da0073e9SAndroid Build Coastguard Worker    def __setstate__(self, state: Dict[str, Any]) -> None:
258*da0073e9SAndroid Build Coastguard Worker        _erased = state.pop("_erased")
259*da0073e9SAndroid Build Coastguard Worker        _prev = state.pop("_prev")
260*da0073e9SAndroid Build Coastguard Worker        _next = state.pop("_next")
261*da0073e9SAndroid Build Coastguard Worker        self.__dict__.update(state)
262*da0073e9SAndroid Build Coastguard Worker        self._erased = _erased
263*da0073e9SAndroid Build Coastguard Worker        self._prev = _prev
264*da0073e9SAndroid Build Coastguard Worker        self._next = _next
265*da0073e9SAndroid Build Coastguard Worker
266*da0073e9SAndroid Build Coastguard Worker    @property
267*da0073e9SAndroid Build Coastguard Worker    def next(self) -> 'Node':
268*da0073e9SAndroid Build Coastguard Worker        """
269*da0073e9SAndroid Build Coastguard Worker        Returns the next ``Node`` in the linked list of Nodes.
270*da0073e9SAndroid Build Coastguard Worker
271*da0073e9SAndroid Build Coastguard Worker        Returns:
272*da0073e9SAndroid Build Coastguard Worker
273*da0073e9SAndroid Build Coastguard Worker            The next ``Node`` in the linked list of Nodes.
274*da0073e9SAndroid Build Coastguard Worker        """
275*da0073e9SAndroid Build Coastguard Worker        return self._next
276*da0073e9SAndroid Build Coastguard Worker
277*da0073e9SAndroid Build Coastguard Worker    @property
278*da0073e9SAndroid Build Coastguard Worker    def prev(self) -> 'Node':
279*da0073e9SAndroid Build Coastguard Worker        """
280*da0073e9SAndroid Build Coastguard Worker        Returns the previous ``Node`` in the linked list of Nodes.
281*da0073e9SAndroid Build Coastguard Worker
282*da0073e9SAndroid Build Coastguard Worker        Returns:
283*da0073e9SAndroid Build Coastguard Worker
284*da0073e9SAndroid Build Coastguard Worker            The previous ``Node`` in the linked list of Nodes.
285*da0073e9SAndroid Build Coastguard Worker        """
286*da0073e9SAndroid Build Coastguard Worker        return self._prev
287*da0073e9SAndroid Build Coastguard Worker
    @compatibility(is_backward_compatible=True)
    def prepend(self, x: 'Node') -> None:
        """
        Insert x before this node in the list of nodes in the graph. Example::

            Before: p -> self
                    bx -> x -> ax
            After:  p -> x -> self
                    bx -> ax

        Args:
            x (Node): The node to put before this node. Must be a member of the same graph.

        Prepending a node to itself only emits a warning and leaves the graph
        unchanged. After relinking, ``x._sort_key`` is recomputed so that it
        orders between its new neighbours' keys.
        """
        # NOTE: bare assert -- skipped under ``python -O``.
        assert self.graph == x.graph, "Attempting to move a Node into a different Graph"
        if self == x:
            warnings.warn("Trying to prepend a node to itself. This behavior has no effect on the graph.")
            return
        # Unlink x from wherever it currently is, then splice it in between
        # p (our old predecessor) and self.
        x._remove_from_list()
        p = self._prev
        p._next, x._prev = x, p
        x._next, self._prev = self, x

        # compute x._sort_key
        # Sort keys are tuples compared lexicographically. Provided the
        # neighbours are already ordered (psk < nsk), each branch yields a key
        # with psk < key < nsk without touching any other node's key.
        psk = x._prev._sort_key
        nsk = x._next._sort_key
        if len(psk) > len(nsk):
            # Take psk truncated to len(nsk) + 1 elements and bump its last
            # element, e.g. psk=(1, 5, 3), nsk=(2,) -> (1, 6).
            idx: int
            *prefix, idx = psk[:len(nsk) + 1]
            x._sort_key = (*prefix, idx + 1)
        elif len(psk) < len(nsk):
            # Take nsk truncated to len(psk) + 1 elements and decrement its
            # last element, e.g. psk=(), nsk=(0,) -> (-1,).
            *prefix, idx = nsk[:len(psk) + 1]
            x._sort_key = (*prefix, idx - 1)
        else:  # same length, increase length by 1
            # psk + (0,) sorts right after psk, and still before nsk because
            # the equal-length keys already differ at an earlier position.
            x._sort_key = (*psk, 0)
322*da0073e9SAndroid Build Coastguard Worker
323*da0073e9SAndroid Build Coastguard Worker    def __gt__(self, other: 'Node') -> bool:
324*da0073e9SAndroid Build Coastguard Worker        return self._sort_key > other._sort_key
325*da0073e9SAndroid Build Coastguard Worker
326*da0073e9SAndroid Build Coastguard Worker    def __lt__(self, other: 'Node') -> bool:
327*da0073e9SAndroid Build Coastguard Worker        return self._sort_key < other._sort_key
328*da0073e9SAndroid Build Coastguard Worker
329*da0073e9SAndroid Build Coastguard Worker    def __ge__(self, other: 'Node') -> bool:
330*da0073e9SAndroid Build Coastguard Worker        return self > other or self == other
331*da0073e9SAndroid Build Coastguard Worker
332*da0073e9SAndroid Build Coastguard Worker    def __le__(self, other: 'Node') -> bool:
333*da0073e9SAndroid Build Coastguard Worker        return self < other or self == other
334*da0073e9SAndroid Build Coastguard Worker
335*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
336*da0073e9SAndroid Build Coastguard Worker    def append(self, x: 'Node') -> None:
337*da0073e9SAndroid Build Coastguard Worker        """
338*da0073e9SAndroid Build Coastguard Worker        Insert ``x`` after this node in the list of nodes in the graph.
339*da0073e9SAndroid Build Coastguard Worker        Equivalent to ``self.next.prepend(x)``
340*da0073e9SAndroid Build Coastguard Worker
341*da0073e9SAndroid Build Coastguard Worker        Args:
342*da0073e9SAndroid Build Coastguard Worker            x (Node): The node to put after this node. Must be a member of the same graph.
343*da0073e9SAndroid Build Coastguard Worker        """
344*da0073e9SAndroid Build Coastguard Worker        self._next.prepend(x)
345*da0073e9SAndroid Build Coastguard Worker
346*da0073e9SAndroid Build Coastguard Worker    def _remove_from_list(self) -> None:
347*da0073e9SAndroid Build Coastguard Worker        p, n = self._prev, self._next
348*da0073e9SAndroid Build Coastguard Worker        p._next, n._prev = n, p
349*da0073e9SAndroid Build Coastguard Worker
350*da0073e9SAndroid Build Coastguard Worker    @property
351*da0073e9SAndroid Build Coastguard Worker    def args(self) -> Tuple[Argument, ...]:
352*da0073e9SAndroid Build Coastguard Worker        """
353*da0073e9SAndroid Build Coastguard Worker        The tuple of arguments to this ``Node``. The interpretation of arguments
354*da0073e9SAndroid Build Coastguard Worker        depends on the node's opcode. See the :class:`Node` docstring for more
355*da0073e9SAndroid Build Coastguard Worker        information.
356*da0073e9SAndroid Build Coastguard Worker
357*da0073e9SAndroid Build Coastguard Worker        Assignment to this property is allowed. All accounting of uses and users
358*da0073e9SAndroid Build Coastguard Worker        is updated automatically on assignment.
359*da0073e9SAndroid Build Coastguard Worker        """
360*da0073e9SAndroid Build Coastguard Worker        return self._args
361*da0073e9SAndroid Build Coastguard Worker
362*da0073e9SAndroid Build Coastguard Worker    @args.setter
363*da0073e9SAndroid Build Coastguard Worker    def args(self, a : Tuple[Argument, ...]) -> None:
364*da0073e9SAndroid Build Coastguard Worker        """
365*da0073e9SAndroid Build Coastguard Worker        Set the tuple of arguments to this Node. The interpretation of arguments
366*da0073e9SAndroid Build Coastguard Worker        depends on the node's opcode. See the ``fx.Graph`` docstring for more
367*da0073e9SAndroid Build Coastguard Worker        information.
368*da0073e9SAndroid Build Coastguard Worker        """
369*da0073e9SAndroid Build Coastguard Worker        # DO NOT CALL `__update_args_kwargs` directly. The correct way to
370*da0073e9SAndroid Build Coastguard Worker        # set `args` is via direct assignment, i.e. `node.args = new_args`
371*da0073e9SAndroid Build Coastguard Worker        self.__update_args_kwargs(a, self._kwargs)
372*da0073e9SAndroid Build Coastguard Worker
373*da0073e9SAndroid Build Coastguard Worker    @property
374*da0073e9SAndroid Build Coastguard Worker    def kwargs(self) -> Dict[str, Argument]:
375*da0073e9SAndroid Build Coastguard Worker        """
376*da0073e9SAndroid Build Coastguard Worker        The dict of keyword arguments to this ``Node``. The interpretation of arguments
377*da0073e9SAndroid Build Coastguard Worker        depends on the node's opcode. See the :class:`Node` docstring for more
378*da0073e9SAndroid Build Coastguard Worker        information.
379*da0073e9SAndroid Build Coastguard Worker
380*da0073e9SAndroid Build Coastguard Worker        Assignment to this property is allowed. All accounting of uses and users
381*da0073e9SAndroid Build Coastguard Worker        is updated automatically on assignment.
382*da0073e9SAndroid Build Coastguard Worker        """
383*da0073e9SAndroid Build Coastguard Worker        return self._kwargs
384*da0073e9SAndroid Build Coastguard Worker
385*da0073e9SAndroid Build Coastguard Worker    @kwargs.setter
386*da0073e9SAndroid Build Coastguard Worker    def kwargs(self, k : Dict[str, Argument]) -> None:
387*da0073e9SAndroid Build Coastguard Worker        """
388*da0073e9SAndroid Build Coastguard Worker        Set the dict of kwargs to this Node. The interpretation of arguments
389*da0073e9SAndroid Build Coastguard Worker        depends on the node's opcode. See the ``fx.Graph`` docstring for more
390*da0073e9SAndroid Build Coastguard Worker        information.
391*da0073e9SAndroid Build Coastguard Worker        """
392*da0073e9SAndroid Build Coastguard Worker        # DO NOT CALL `__update_args_kwargs` directly. The correct way to
393*da0073e9SAndroid Build Coastguard Worker        # set `args` is via direct assignment, i.e. `node.kwargs = new_kwargs`
394*da0073e9SAndroid Build Coastguard Worker        self.__update_args_kwargs(self._args, k)
395*da0073e9SAndroid Build Coastguard Worker
396*da0073e9SAndroid Build Coastguard Worker    @property
397*da0073e9SAndroid Build Coastguard Worker    def all_input_nodes(self) -> List['Node']:
398*da0073e9SAndroid Build Coastguard Worker        """
399*da0073e9SAndroid Build Coastguard Worker        Return all Nodes that are inputs to this Node. This is equivalent to
400*da0073e9SAndroid Build Coastguard Worker        iterating over ``args`` and ``kwargs`` and only collecting the values that
401*da0073e9SAndroid Build Coastguard Worker        are Nodes.
402*da0073e9SAndroid Build Coastguard Worker
403*da0073e9SAndroid Build Coastguard Worker        Returns:
404*da0073e9SAndroid Build Coastguard Worker
405*da0073e9SAndroid Build Coastguard Worker            List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this
406*da0073e9SAndroid Build Coastguard Worker            ``Node``, in that order.
407*da0073e9SAndroid Build Coastguard Worker        """
408*da0073e9SAndroid Build Coastguard Worker        return list(self._input_nodes.keys())
409*da0073e9SAndroid Build Coastguard Worker
410*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
411*da0073e9SAndroid Build Coastguard Worker    def update_arg(self, idx : int, arg : Argument) -> None:
412*da0073e9SAndroid Build Coastguard Worker        """
413*da0073e9SAndroid Build Coastguard Worker        Update an existing positional argument to contain the new value
414*da0073e9SAndroid Build Coastguard Worker        ``arg``. After calling, ``self.args[idx] == arg``.
415*da0073e9SAndroid Build Coastguard Worker
416*da0073e9SAndroid Build Coastguard Worker        Args:
417*da0073e9SAndroid Build Coastguard Worker
418*da0073e9SAndroid Build Coastguard Worker            idx (int): The index into ``self.args`` of the element to update
419*da0073e9SAndroid Build Coastguard Worker            arg (Argument): The new argument value to write into ``args``
420*da0073e9SAndroid Build Coastguard Worker        """
421*da0073e9SAndroid Build Coastguard Worker        args = list(self.args)
422*da0073e9SAndroid Build Coastguard Worker        args[idx] = arg
423*da0073e9SAndroid Build Coastguard Worker        self.args = tuple(args)
424*da0073e9SAndroid Build Coastguard Worker
425*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
426*da0073e9SAndroid Build Coastguard Worker    def insert_arg(self, idx : int, arg : Argument) -> None:
427*da0073e9SAndroid Build Coastguard Worker        """
428*da0073e9SAndroid Build Coastguard Worker        Insert an positional argument to the argument list with given index.
429*da0073e9SAndroid Build Coastguard Worker
430*da0073e9SAndroid Build Coastguard Worker        Args:
431*da0073e9SAndroid Build Coastguard Worker
432*da0073e9SAndroid Build Coastguard Worker            idx (int): The index of the element in ``self.args`` to be inserted before.
433*da0073e9SAndroid Build Coastguard Worker            arg (Argument): The new argument value to insert into ``args``
434*da0073e9SAndroid Build Coastguard Worker        """
435*da0073e9SAndroid Build Coastguard Worker        assert 0 <= idx <= len(self.args), "insert_args index must be between 0 and len(self.args)"
436*da0073e9SAndroid Build Coastguard Worker        args_left = self.args[:idx]
437*da0073e9SAndroid Build Coastguard Worker        args_right = self.args[idx:]
438*da0073e9SAndroid Build Coastguard Worker
439*da0073e9SAndroid Build Coastguard Worker        self._args = args_left + (arg,) + args_right
440*da0073e9SAndroid Build Coastguard Worker
441*da0073e9SAndroid Build Coastguard Worker        _new_input_nodes: Dict[Node, None] = {}
442*da0073e9SAndroid Build Coastguard Worker        map_arg(arg, _new_input_nodes.setdefault)
443*da0073e9SAndroid Build Coastguard Worker
444*da0073e9SAndroid Build Coastguard Worker        for new_use in _new_input_nodes.keys():
445*da0073e9SAndroid Build Coastguard Worker            if new_use not in self._input_nodes:
446*da0073e9SAndroid Build Coastguard Worker                self._input_nodes.setdefault(new_use)
447*da0073e9SAndroid Build Coastguard Worker                new_use.users.setdefault(self)
448*da0073e9SAndroid Build Coastguard Worker
449*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
450*da0073e9SAndroid Build Coastguard Worker    def update_kwarg(self, key : str, arg : Argument) -> None:
451*da0073e9SAndroid Build Coastguard Worker        """
452*da0073e9SAndroid Build Coastguard Worker        Update an existing keyword argument to contain the new value
453*da0073e9SAndroid Build Coastguard Worker        ``arg``. After calling, ``self.kwargs[key] == arg``.
454*da0073e9SAndroid Build Coastguard Worker
455*da0073e9SAndroid Build Coastguard Worker        Args:
456*da0073e9SAndroid Build Coastguard Worker
457*da0073e9SAndroid Build Coastguard Worker            key (str): The key in ``self.kwargs`` of the element to update
458*da0073e9SAndroid Build Coastguard Worker            arg (Argument): The new argument value to write into ``kwargs``
459*da0073e9SAndroid Build Coastguard Worker        """
460*da0073e9SAndroid Build Coastguard Worker        self.kwargs = {**self.kwargs, key: arg}
461*da0073e9SAndroid Build Coastguard Worker
462*da0073e9SAndroid Build Coastguard Worker    @property
463*da0073e9SAndroid Build Coastguard Worker    def stack_trace(self) -> Optional[str]:
464*da0073e9SAndroid Build Coastguard Worker        """
465*da0073e9SAndroid Build Coastguard Worker        Return the Python stack trace that was recorded during tracing, if any.
466*da0073e9SAndroid Build Coastguard Worker        When traced with fx.Tracer, this property is usually populated by
467*da0073e9SAndroid Build Coastguard Worker        `Tracer.create_proxy`. To record stack traces during tracing for debug purposes,
468*da0073e9SAndroid Build Coastguard Worker        set `record_stack_traces = True` on the `Tracer` instance.
469*da0073e9SAndroid Build Coastguard Worker        When traced with dynamo, this property will be populated by default by
470*da0073e9SAndroid Build Coastguard Worker        `OutputGraph.create_proxy`.
471*da0073e9SAndroid Build Coastguard Worker
472*da0073e9SAndroid Build Coastguard Worker        stack_trace would have the innermost frame at the end of the string.
473*da0073e9SAndroid Build Coastguard Worker        """
474*da0073e9SAndroid Build Coastguard Worker        return self.meta.get("stack_trace", None)
475*da0073e9SAndroid Build Coastguard Worker
476*da0073e9SAndroid Build Coastguard Worker    @stack_trace.setter
477*da0073e9SAndroid Build Coastguard Worker    def stack_trace(self, trace : Optional[str]) -> None:
478*da0073e9SAndroid Build Coastguard Worker        self.meta["stack_trace"] = trace
479*da0073e9SAndroid Build Coastguard Worker
480*da0073e9SAndroid Build Coastguard Worker    def __update_args_kwargs(self, new_args : Tuple['Argument', ...], new_kwargs : Dict[str, 'Argument']) -> None:
481*da0073e9SAndroid Build Coastguard Worker        """
482*da0073e9SAndroid Build Coastguard Worker        This API is internal. Do *not* call it directly.
483*da0073e9SAndroid Build Coastguard Worker        """
484*da0073e9SAndroid Build Coastguard Worker        def update_users_and_input_nodes(n: Any) -> Any:
485*da0073e9SAndroid Build Coastguard Worker            if isinstance(n, Node):
486*da0073e9SAndroid Build Coastguard Worker                self._input_nodes.setdefault(n)
487*da0073e9SAndroid Build Coastguard Worker                n.users.setdefault(self)
488*da0073e9SAndroid Build Coastguard Worker            return n
489*da0073e9SAndroid Build Coastguard Worker
490*da0073e9SAndroid Build Coastguard Worker        # Clear prior users and input_nodes
491*da0073e9SAndroid Build Coastguard Worker        for old_use in self._input_nodes.keys():
492*da0073e9SAndroid Build Coastguard Worker            old_use.users.pop(self)
493*da0073e9SAndroid Build Coastguard Worker        self._input_nodes = {}
494*da0073e9SAndroid Build Coastguard Worker
495*da0073e9SAndroid Build Coastguard Worker        # We do three things in a single pass of the args
496*da0073e9SAndroid Build Coastguard Worker        # - Normalize list->immutable_list, dict->immutable_dict, etc
497*da0073e9SAndroid Build Coastguard Worker        # - Populate self._input_nodes
498*da0073e9SAndroid Build Coastguard Worker        # - Populate arg.users[self] for each arg
499*da0073e9SAndroid Build Coastguard Worker        self._args = map_aggregate(new_args, update_users_and_input_nodes)  # type: ignore[assignment]
500*da0073e9SAndroid Build Coastguard Worker        self._kwargs = map_aggregate(new_kwargs, update_users_and_input_nodes)  # type: ignore[assignment]
501*da0073e9SAndroid Build Coastguard Worker
502*da0073e9SAndroid Build Coastguard Worker    def __repr__(self) -> str:
503*da0073e9SAndroid Build Coastguard Worker        if self._repr_fn:
504*da0073e9SAndroid Build Coastguard Worker            return self._repr_fn(self)
505*da0073e9SAndroid Build Coastguard Worker        return self.name
506*da0073e9SAndroid Build Coastguard Worker
507*da0073e9SAndroid Build Coastguard Worker    def _pretty_print_target(self, target: object) -> str:
508*da0073e9SAndroid Build Coastguard Worker        """
509*da0073e9SAndroid Build Coastguard Worker        Make target printouts more user-friendly.
510*da0073e9SAndroid Build Coastguard Worker        1) builtins will be printed as `builtins.xyz`
511*da0073e9SAndroid Build Coastguard Worker        2) operators will be printed as `operator.xyz`
512*da0073e9SAndroid Build Coastguard Worker        3) other callables will be printed with qualified name, e.g. torch.add
513*da0073e9SAndroid Build Coastguard Worker        """
514*da0073e9SAndroid Build Coastguard Worker        if isinstance(target, str):
515*da0073e9SAndroid Build Coastguard Worker            return target
516*da0073e9SAndroid Build Coastguard Worker        if hasattr(target, '__module__'):
517*da0073e9SAndroid Build Coastguard Worker            name = getattr(target, '__name__', None)
518*da0073e9SAndroid Build Coastguard Worker            if name is None:
519*da0073e9SAndroid Build Coastguard Worker                # Just to be defensive, if we don't have `__name__`, get the
520*da0073e9SAndroid Build Coastguard Worker                # qualname. Not sure if this happens for any members of `operator`
521*da0073e9SAndroid Build Coastguard Worker                # or `builtins`. This fallback path is not as good, since e.g.
522*da0073e9SAndroid Build Coastguard Worker                # things in `operator` have `_operator` as their __module__.
523*da0073e9SAndroid Build Coastguard Worker                # TODO: THIS IS BROKEN: _get_qualified_name calls `__name__`
524*da0073e9SAndroid Build Coastguard Worker                return _get_qualified_name(target)  # type: ignore[arg-type]
525*da0073e9SAndroid Build Coastguard Worker            if target.__module__ == 'builtins':
526*da0073e9SAndroid Build Coastguard Worker                return f'builtins.{name}'
527*da0073e9SAndroid Build Coastguard Worker            elif target.__module__ == '_operator':
528*da0073e9SAndroid Build Coastguard Worker                return f'operator.{name}'
529*da0073e9SAndroid Build Coastguard Worker        return _get_qualified_name(target)  # type: ignore[arg-type]
530*da0073e9SAndroid Build Coastguard Worker
531*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
532*da0073e9SAndroid Build Coastguard Worker    def format_node(self,
533*da0073e9SAndroid Build Coastguard Worker                    placeholder_names: Optional[List[str]] = None,
534*da0073e9SAndroid Build Coastguard Worker                    maybe_return_typename: Optional[List[str]] = None) -> Optional[str]:
535*da0073e9SAndroid Build Coastguard Worker        """
536*da0073e9SAndroid Build Coastguard Worker        Return a descriptive string representation of ``self``.
537*da0073e9SAndroid Build Coastguard Worker
538*da0073e9SAndroid Build Coastguard Worker        This method can be used with no arguments as a debugging
539*da0073e9SAndroid Build Coastguard Worker        utility.
540*da0073e9SAndroid Build Coastguard Worker
541*da0073e9SAndroid Build Coastguard Worker        This function is also used internally in the ``__str__`` method
542*da0073e9SAndroid Build Coastguard Worker        of ``Graph``. Together, the strings in ``placeholder_names``
543*da0073e9SAndroid Build Coastguard Worker        and ``maybe_return_typename`` make up the signature of the
544*da0073e9SAndroid Build Coastguard Worker        autogenerated ``forward`` function in this Graph's surrounding
545*da0073e9SAndroid Build Coastguard Worker        GraphModule. ``placeholder_names`` and ``maybe_return_typename``
546*da0073e9SAndroid Build Coastguard Worker        should not be used otherwise.
547*da0073e9SAndroid Build Coastguard Worker
548*da0073e9SAndroid Build Coastguard Worker        Args:
549*da0073e9SAndroid Build Coastguard Worker            placeholder_names: A list that will store formatted strings
550*da0073e9SAndroid Build Coastguard Worker                representing the placeholders in the generated
551*da0073e9SAndroid Build Coastguard Worker                ``forward`` function. Internal use only.
552*da0073e9SAndroid Build Coastguard Worker            maybe_return_typename: A single-element list that will store
553*da0073e9SAndroid Build Coastguard Worker                a formatted string representing the output of the
554*da0073e9SAndroid Build Coastguard Worker                generated ``forward`` function. Internal use only.
555*da0073e9SAndroid Build Coastguard Worker
556*da0073e9SAndroid Build Coastguard Worker        Returns:
557*da0073e9SAndroid Build Coastguard Worker            str: If 1) we're using ``format_node`` as an internal helper
558*da0073e9SAndroid Build Coastguard Worker                in the ``__str__`` method of ``Graph``, and 2) ``self``
559*da0073e9SAndroid Build Coastguard Worker                is a placeholder Node, return ``None``. Otherwise,
560*da0073e9SAndroid Build Coastguard Worker                return a  descriptive string representation of the
561*da0073e9SAndroid Build Coastguard Worker                current Node.
562*da0073e9SAndroid Build Coastguard Worker        """
563*da0073e9SAndroid Build Coastguard Worker        if self.op == 'placeholder':
564*da0073e9SAndroid Build Coastguard Worker            assert isinstance(self.target, str)
565*da0073e9SAndroid Build Coastguard Worker            arg_str = self.target
566*da0073e9SAndroid Build Coastguard Worker            arg_str += arg_str + f': {_type_repr(self.type)}' if self.type else ''
567*da0073e9SAndroid Build Coastguard Worker            if placeholder_names:
568*da0073e9SAndroid Build Coastguard Worker                placeholder_names.append(arg_str)
569*da0073e9SAndroid Build Coastguard Worker                return None
570*da0073e9SAndroid Build Coastguard Worker            maybe_typename = f'{_type_repr(self.type)} ' if self.type else ''
571*da0073e9SAndroid Build Coastguard Worker            default_val = '(default=' + str(self.args[0]) + ')' if self.args else ''
572*da0073e9SAndroid Build Coastguard Worker            return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = {self.op}[target={self.target}]{default_val}'
573*da0073e9SAndroid Build Coastguard Worker        elif self.op == 'get_attr':
574*da0073e9SAndroid Build Coastguard Worker            maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else ''
575*da0073e9SAndroid Build Coastguard Worker            return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = ' \
576*da0073e9SAndroid Build Coastguard Worker                   f'{self.op}[target={self._pretty_print_target(self.target)}]'
577*da0073e9SAndroid Build Coastguard Worker        elif self.op == 'output':
578*da0073e9SAndroid Build Coastguard Worker            if self.type and maybe_return_typename:
579*da0073e9SAndroid Build Coastguard Worker                maybe_return_typename[0] = f' -> {_type_repr(self.type)}'
580*da0073e9SAndroid Build Coastguard Worker            return f'return {self.args[0]}'
581*da0073e9SAndroid Build Coastguard Worker        else:
582*da0073e9SAndroid Build Coastguard Worker            maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else ''
583*da0073e9SAndroid Build Coastguard Worker            return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = ' \
584*da0073e9SAndroid Build Coastguard Worker                   f'{self.op}[target={self._pretty_print_target(self.target)}](' \
585*da0073e9SAndroid Build Coastguard Worker                   f'args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})'
586*da0073e9SAndroid Build Coastguard Worker
587*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
588*da0073e9SAndroid Build Coastguard Worker    def replace_all_uses_with(self,
589*da0073e9SAndroid Build Coastguard Worker                              replace_with: 'Node',
590*da0073e9SAndroid Build Coastguard Worker                              delete_user_cb: Callable[['Node'], bool] = lambda user: True,
591*da0073e9SAndroid Build Coastguard Worker                              *,
592*da0073e9SAndroid Build Coastguard Worker                              propagate_meta: bool = False
593*da0073e9SAndroid Build Coastguard Worker                              ) -> List['Node']:
594*da0073e9SAndroid Build Coastguard Worker        """
595*da0073e9SAndroid Build Coastguard Worker        Replace all uses of ``self`` in the Graph with the Node ``replace_with``.
596*da0073e9SAndroid Build Coastguard Worker
597*da0073e9SAndroid Build Coastguard Worker        Args:
598*da0073e9SAndroid Build Coastguard Worker
599*da0073e9SAndroid Build Coastguard Worker            replace_with (Node): The node to replace all uses of ``self`` with.
600*da0073e9SAndroid Build Coastguard Worker            delete_user_cb (Callable): Callback that is called to determine
601*da0073e9SAndroid Build Coastguard Worker              whether a given user of the self node should be removed.
602*da0073e9SAndroid Build Coastguard Worker            propagate_meta (bool): Whether or not to copy all properties
603*da0073e9SAndroid Build Coastguard Worker              on the .meta field of the original node onto the replacement node.
604*da0073e9SAndroid Build Coastguard Worker              For safety, this is only valid to do if the replacement node
605*da0073e9SAndroid Build Coastguard Worker              doesn't already have an existing .meta field.
606*da0073e9SAndroid Build Coastguard Worker
607*da0073e9SAndroid Build Coastguard Worker        Returns:
608*da0073e9SAndroid Build Coastguard Worker
609*da0073e9SAndroid Build Coastguard Worker            The list of Nodes on which this change was made.
610*da0073e9SAndroid Build Coastguard Worker        """
611*da0073e9SAndroid Build Coastguard Worker        if propagate_meta:
612*da0073e9SAndroid Build Coastguard Worker            assert len(replace_with.meta) == 0, \
613*da0073e9SAndroid Build Coastguard Worker                'Called node.replace_all_uses_with(replace_with, propagate_meta=True), ' \
614*da0073e9SAndroid Build Coastguard Worker                'but replace_with already has .meta keys'
615*da0073e9SAndroid Build Coastguard Worker            for k, v in self.meta.items():
616*da0073e9SAndroid Build Coastguard Worker                replace_with.meta[k] = v
617*da0073e9SAndroid Build Coastguard Worker        to_process = list(self.users)
618*da0073e9SAndroid Build Coastguard Worker        skipped = []
619*da0073e9SAndroid Build Coastguard Worker        m = self.graph.owning_module
620*da0073e9SAndroid Build Coastguard Worker        for use_node in to_process:
621*da0073e9SAndroid Build Coastguard Worker            if not delete_user_cb(use_node):
622*da0073e9SAndroid Build Coastguard Worker                skipped.append(use_node)
623*da0073e9SAndroid Build Coastguard Worker                continue
624*da0073e9SAndroid Build Coastguard Worker
625*da0073e9SAndroid Build Coastguard Worker            def maybe_replace_node(n : Node) -> Node:
626*da0073e9SAndroid Build Coastguard Worker                if n == self:
627*da0073e9SAndroid Build Coastguard Worker                    return replace_with
628*da0073e9SAndroid Build Coastguard Worker                else:
629*da0073e9SAndroid Build Coastguard Worker                    return n
630*da0073e9SAndroid Build Coastguard Worker
631*da0073e9SAndroid Build Coastguard Worker            if getattr(m, "_replace_hook", None):
632*da0073e9SAndroid Build Coastguard Worker                m._replace_hook(old=self, new=replace_with.name, user=use_node)
633*da0073e9SAndroid Build Coastguard Worker
634*da0073e9SAndroid Build Coastguard Worker            new_args = map_arg(use_node.args, maybe_replace_node)
635*da0073e9SAndroid Build Coastguard Worker            new_kwargs = map_arg(use_node.kwargs, maybe_replace_node)
636*da0073e9SAndroid Build Coastguard Worker            assert isinstance(new_args, tuple)
637*da0073e9SAndroid Build Coastguard Worker            assert isinstance(new_kwargs, dict)
638*da0073e9SAndroid Build Coastguard Worker            use_node.__update_args_kwargs(new_args, new_kwargs)
639*da0073e9SAndroid Build Coastguard Worker
640*da0073e9SAndroid Build Coastguard Worker        assert len(self.users) - len(skipped) == 0
641*da0073e9SAndroid Build Coastguard Worker        return [n for n in to_process if n not in skipped]
642*da0073e9SAndroid Build Coastguard Worker
643*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=False)
644*da0073e9SAndroid Build Coastguard Worker    def is_impure(self) -> bool:
645*da0073e9SAndroid Build Coastguard Worker        """
646*da0073e9SAndroid Build Coastguard Worker        Returns whether this op is impure, i.e. if its op is a placeholder or
647*da0073e9SAndroid Build Coastguard Worker        output, or if a call_function or call_module which is impure.
648*da0073e9SAndroid Build Coastguard Worker
649*da0073e9SAndroid Build Coastguard Worker        Returns:
650*da0073e9SAndroid Build Coastguard Worker
651*da0073e9SAndroid Build Coastguard Worker            bool: If the op is impure or not.
652*da0073e9SAndroid Build Coastguard Worker        """
653*da0073e9SAndroid Build Coastguard Worker        if self.op in {"placeholder", "output"}:
654*da0073e9SAndroid Build Coastguard Worker            return True
655*da0073e9SAndroid Build Coastguard Worker
656*da0073e9SAndroid Build Coastguard Worker        # Check if an impure function based on schema.
657*da0073e9SAndroid Build Coastguard Worker        if self.op == "call_function":
658*da0073e9SAndroid Build Coastguard Worker            schema = getattr(self.target, "_schema", None)
659*da0073e9SAndroid Build Coastguard Worker            schema_mutable = schema is not None and schema.is_mutable
660*da0073e9SAndroid Build Coastguard Worker            return schema_mutable or self.target in _side_effectful_functions
661*da0073e9SAndroid Build Coastguard Worker
662*da0073e9SAndroid Build Coastguard Worker        # Check if an impure module.
663*da0073e9SAndroid Build Coastguard Worker        if self.op == "call_module":
664*da0073e9SAndroid Build Coastguard Worker            assert (
665*da0073e9SAndroid Build Coastguard Worker                self.graph.owning_module is not None
666*da0073e9SAndroid Build Coastguard Worker            ), "self.graph.owning_module not set for purity check"
667*da0073e9SAndroid Build Coastguard Worker            target_mod = self.graph.owning_module.get_submodule(self.target)
668*da0073e9SAndroid Build Coastguard Worker            assert (
669*da0073e9SAndroid Build Coastguard Worker                target_mod is not None
670*da0073e9SAndroid Build Coastguard Worker            ), f"Did not find expected submodule target {self.target}"
671*da0073e9SAndroid Build Coastguard Worker            return getattr(target_mod, "_is_impure", False)
672*da0073e9SAndroid Build Coastguard Worker
673*da0073e9SAndroid Build Coastguard Worker        return False
674*da0073e9SAndroid Build Coastguard Worker
675*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=False)
676*da0073e9SAndroid Build Coastguard Worker    def normalized_arguments(
677*da0073e9SAndroid Build Coastguard Worker            self, root : torch.nn.Module, arg_types : Optional[Tuple[Any]] = None,
678*da0073e9SAndroid Build Coastguard Worker            kwarg_types : Optional[Dict[str, Any]] = None,
679*da0073e9SAndroid Build Coastguard Worker            normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
680*da0073e9SAndroid Build Coastguard Worker        """
681*da0073e9SAndroid Build Coastguard Worker        Returns normalized arguments to Python targets. This means that
682*da0073e9SAndroid Build Coastguard Worker        `args/kwargs` will be matched up to the module/functional's
683*da0073e9SAndroid Build Coastguard Worker        signature and return exclusively kwargs in positional order
684*da0073e9SAndroid Build Coastguard Worker        if `normalize_to_only_use_kwargs` is true.
685*da0073e9SAndroid Build Coastguard Worker        Also populates default values. Does not support positional-only
686*da0073e9SAndroid Build Coastguard Worker        parameters or varargs parameters.
687*da0073e9SAndroid Build Coastguard Worker
688*da0073e9SAndroid Build Coastguard Worker        Supports module calls.
689*da0073e9SAndroid Build Coastguard Worker
690*da0073e9SAndroid Build Coastguard Worker        May require `arg_types` and `kwarg_types` in order to disambiguate overloads.
691*da0073e9SAndroid Build Coastguard Worker
692*da0073e9SAndroid Build Coastguard Worker        Args:
693*da0073e9SAndroid Build Coastguard Worker            root (torch.nn.Module): Module upon which to resolve module targets.
694*da0073e9SAndroid Build Coastguard Worker            arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args
695*da0073e9SAndroid Build Coastguard Worker            kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs
696*da0073e9SAndroid Build Coastguard Worker            normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.
697*da0073e9SAndroid Build Coastguard Worker
698*da0073e9SAndroid Build Coastguard Worker        Returns:
699*da0073e9SAndroid Build Coastguard Worker
700*da0073e9SAndroid Build Coastguard Worker            Returns NamedTuple ArgsKwargsPair, or `None` if not successful.
701*da0073e9SAndroid Build Coastguard Worker        """
702*da0073e9SAndroid Build Coastguard Worker        if self.op == 'call_function':
703*da0073e9SAndroid Build Coastguard Worker            assert callable(self.target)
704*da0073e9SAndroid Build Coastguard Worker            return normalize_function(self.target, self.args, self.kwargs, arg_types, kwarg_types)  # type: ignore[arg-type]
705*da0073e9SAndroid Build Coastguard Worker        elif self.op == 'call_module':
706*da0073e9SAndroid Build Coastguard Worker            assert isinstance(self.target, str)
707*da0073e9SAndroid Build Coastguard Worker            return normalize_module(root, self.target, self.args, self.kwargs)  # type: ignore[arg-type]
708*da0073e9SAndroid Build Coastguard Worker
709*da0073e9SAndroid Build Coastguard Worker        return None
710*da0073e9SAndroid Build Coastguard Worker
711*da0073e9SAndroid Build Coastguard Worker    @compatibility(is_backward_compatible=True)
712*da0073e9SAndroid Build Coastguard Worker    def replace_input_with(self, old_input: 'Node', new_input: 'Node') -> None:
713*da0073e9SAndroid Build Coastguard Worker        """
714*da0073e9SAndroid Build Coastguard Worker        Loop through input nodes of ``self``, and replace all instances of
715*da0073e9SAndroid Build Coastguard Worker        ``old_input`` with ``new_input``.
716*da0073e9SAndroid Build Coastguard Worker
717*da0073e9SAndroid Build Coastguard Worker        Args:
718*da0073e9SAndroid Build Coastguard Worker
719*da0073e9SAndroid Build Coastguard Worker            old_input (Node): The old input node to be replaced.
720*da0073e9SAndroid Build Coastguard Worker            new_input (Node): The new input node to replace ``old_input``.
721*da0073e9SAndroid Build Coastguard Worker        """
722*da0073e9SAndroid Build Coastguard Worker        def maybe_replace_node(n : Node) -> Node:
723*da0073e9SAndroid Build Coastguard Worker            return new_input if n == old_input else n
724*da0073e9SAndroid Build Coastguard Worker
725*da0073e9SAndroid Build Coastguard Worker        m = self.graph.owning_module
726*da0073e9SAndroid Build Coastguard Worker        if getattr(m, "_replace_hook", None):
727*da0073e9SAndroid Build Coastguard Worker            m._replace_hook(old=old_input, new=new_input.name, user=self)
728*da0073e9SAndroid Build Coastguard Worker
729*da0073e9SAndroid Build Coastguard Worker        new_args = map_arg(self.args, maybe_replace_node)
730*da0073e9SAndroid Build Coastguard Worker        new_kwargs = map_arg(self.kwargs, maybe_replace_node)
731*da0073e9SAndroid Build Coastguard Worker        assert isinstance(new_args, tuple)
732*da0073e9SAndroid Build Coastguard Worker        assert isinstance(new_kwargs, dict)
733*da0073e9SAndroid Build Coastguard Worker        self.__update_args_kwargs(new_args, new_kwargs)
734*da0073e9SAndroid Build Coastguard Worker
735*da0073e9SAndroid Build Coastguard Worker    def _rename(self, candidate: str) -> None:
736*da0073e9SAndroid Build Coastguard Worker        if candidate == self.name:
737*da0073e9SAndroid Build Coastguard Worker            return
738*da0073e9SAndroid Build Coastguard Worker        name = self.graph._graph_namespace.create_name(candidate, None)
739*da0073e9SAndroid Build Coastguard Worker        self.name = name
740*da0073e9SAndroid Build Coastguard Worker        self.graph._graph_namespace._rename_object(self, name)
741*da0073e9SAndroid Build Coastguard Worker
742*da0073e9SAndroid Build Coastguard Worker    def __setattr__(self, name: str, value: Any) -> None:
743*da0073e9SAndroid Build Coastguard Worker        if name == 'name' and hasattr(self, "name"):
744*da0073e9SAndroid Build Coastguard Worker            m = self.graph.owning_module
745*da0073e9SAndroid Build Coastguard Worker            if getattr(m, "_replace_hook", None):
746*da0073e9SAndroid Build Coastguard Worker                assert isinstance(value, str)
747*da0073e9SAndroid Build Coastguard Worker                for user in self.users:
748*da0073e9SAndroid Build Coastguard Worker                    m._replace_hook(old=self, new=value, user=user)
749*da0073e9SAndroid Build Coastguard Worker        update = False
750*da0073e9SAndroid Build Coastguard Worker        if (
751*da0073e9SAndroid Build Coastguard Worker                hasattr(self, name) and
752*da0073e9SAndroid Build Coastguard Worker                hasattr(self.graph, "_find_nodes_lookup_table") and
753*da0073e9SAndroid Build Coastguard Worker                self in self.graph._find_nodes_lookup_table
754*da0073e9SAndroid Build Coastguard Worker        ):
755*da0073e9SAndroid Build Coastguard Worker            update = True
756*da0073e9SAndroid Build Coastguard Worker            self.graph._find_nodes_lookup_table.remove(self)
757*da0073e9SAndroid Build Coastguard Worker        object.__setattr__(self, name, value)
758*da0073e9SAndroid Build Coastguard Worker        if update:
759*da0073e9SAndroid Build Coastguard Worker            self.graph._find_nodes_lookup_table.insert(self)
760*da0073e9SAndroid Build Coastguard Worker
@compatibility(is_backward_compatible=True)
def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument:
    """
    Apply ``fn`` to every ``Node`` that appears in ``a``.

    ``a`` may be an arbitrarily nested structure of lists, tuples, slices and
    dicts with string keys. The traversal and container rebuilding are done by
    ``map_aggregate``. Leaves that are not ``Node`` objects are returned
    unchanged.
    """
    assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable"

    def _on_leaf(leaf):
        # Only Nodes are transformed; every other leaf value is kept as-is.
        if isinstance(leaf, Node):
            return fn(leaf)
        return leaf

    return map_aggregate(a, _on_leaf)
768*da0073e9SAndroid Build Coastguard Worker
@compatibility(is_backward_compatible=True)
def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument:
    """
    Recursively rebuild ``a``, applying ``fn`` to every leaf value.

    Containers are traversed rather than passed to ``fn``:

    * ``tuple``: each element is mapped. NamedTuples (anything with a
      ``_fields`` attribute) are repacked into their original type. Every
      other tuple, including tuple subclasses without ``_fields``, becomes a
      plain ``tuple``.
    * ``list``: each element is mapped, and the result is an ``immutable_list``.
    * ``dict``: each value is mapped and keys are kept as-is. The result is
      an ``immutable_dict``.
    * ``slice``: ``start``, ``stop`` and ``step`` are each mapped, so ``fn``
      also receives ``None`` for omitted components.

    Any other value (e.g. ``None``, numbers, ``Node`` objects) is a leaf and
    is replaced by ``fn(value)``. Unlike ``map_arg``, ``fn`` is called on
    non-``Node`` leaves as well.

    Args:
        a: The (possibly nested) argument structure to map over.
        fn: Callable applied to each leaf value.

    Returns:
        A new structure of the same shape in which every leaf has been
        replaced by the result of ``fn``.
    """
    if isinstance(a, tuple):
        mapped = tuple([map_aggregate(elem, fn) for elem in a])
        # Support NamedTuple (if it has `_fields`) by repacking into original type.
        return type(a)(*mapped) if hasattr(a, '_fields') else mapped  # type: ignore[arg-type]
    if isinstance(a, list):
        return immutable_list([map_aggregate(elem, fn) for elem in a])
    if isinstance(a, dict):
        rv = immutable_dict()
        for k, v in a.items():
            # Presumably immutable_dict blocks normal item assignment, hence the
            # direct dict.__setitem__ call; confirm in immutable_collections.
            dict.__setitem__(rv, k, map_aggregate(v, fn))
        return rv
    if isinstance(a, slice):
        return slice(map_aggregate(a.start, fn), map_aggregate(a.stop, fn), map_aggregate(a.step, fn))
    return fn(a)
788*da0073e9SAndroid Build Coastguard Worker        return fn(a)
789